diff --git a/.gitignore b/.gitignore index a290ab7d5..523eb404d 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ coverage.xml .hypothesis/ .pytest_cache/ cover/ +tests/async/output/ # Translations *.mo @@ -257,4 +258,7 @@ continue_config.json .private/ CLAUDE_MONITOR.md -CLAUDE.md \ No newline at end of file +CLAUDE.md + +# Test output +logs/ diff --git a/crawl4ai/__init__.py b/crawl4ai/__init__.py index 0ab808f3f..ae18c1f8b 100644 --- a/crawl4ai/__init__.py +++ b/crawl4ai/__init__.py @@ -1,5 +1,6 @@ # __init__.py import warnings +from logging import Logger from .async_webcrawler import AsyncWebCrawler, CacheMode from .async_configs import BrowserConfig, CrawlerRunConfig, HTTPCrawlerConfig, LLMConfig @@ -64,6 +65,7 @@ DFSDeepCrawlStrategy, DeepCrawlDecorator, ) +from .deep_crawling.scorers import ScoringStats __all__ = [ "AsyncLoggerBase", @@ -121,6 +123,8 @@ "Crawl4aiDockerClient", "ProxyRotationStrategy", "RoundRobinProxyStrategy", + "ScoringStats", + "Logger", # Required for serialization ] diff --git a/crawl4ai/async_configs.py b/crawl4ai/async_configs.py index c7f9e739a..4df706cdf 100644 --- a/crawl4ai/async_configs.py +++ b/crawl4ai/async_configs.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from .config import ( DEFAULT_PROVIDER, @@ -12,7 +14,7 @@ ) from .user_agent_generator import UAGen, ValidUAGenerator # , OnlineUAGenerator -from .extraction_strategy import ExtractionStrategy, LLMExtractionStrategy +from .extraction_strategy import ExtractionStrategy from .chunking_strategy import ChunkingStrategy, RegexChunking from .markdown_generation_strategy import MarkdownGenerationStrategy @@ -20,21 +22,23 @@ from .deep_crawling import DeepCrawlStrategy from .cache_context import CacheMode -from .proxy_strategy import ProxyRotationStrategy +from .proxy_strategy import ProxyRotationStrategy, ProxyConfig from typing import Union, List import inspect from typing import Any, Dict, Optional from enum import Enum +from pathlib import Path -from .proxy_strategy import ProxyConfig try: from .browser.docker_config import DockerConfig except ImportError: DockerConfig = None +Serialisable = Optional[Union[str, int, float, bool, List, Dict]] + -def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict: +def to_serializable_dict(obj: Any, ignore_default_value: bool = False) -> Serialisable: """ Recursively convert an object to a serializable dictionary using {type, params} structure for complex objects. @@ -108,7 +112,7 @@ def to_serializable_dict(obj: Any, ignore_default_value : bool = False) -> Dict: return str(obj) -def from_serializable_dict(data: Any) -> Any: +def from_serializable(data: Serialisable) -> Any: """ Recursively convert a serializable dictionary back to an object instance. 
""" @@ -123,7 +127,7 @@ def from_serializable_dict(data: Any) -> Any: if isinstance(data, dict) and "type" in data: # Handle plain dictionaries if data["type"] == "dict": - return {k: from_serializable_dict(v) for k, v in data["value"].items()} + return {k: from_serializable(v) for k, v in data["value"].items()} # Import from crawl4ai for class instances import crawl4ai @@ -135,18 +139,16 @@ def from_serializable_dict(data: Any) -> Any: return cls(data["params"]) # Handle class instances - constructor_args = { - k: from_serializable_dict(v) for k, v in data["params"].items() - } + constructor_args = {k: from_serializable(v) for k, v in data["params"].items()} return cls(**constructor_args) # Handle lists if isinstance(data, list): - return [from_serializable_dict(item) for item in data] + return [from_serializable(item) for item in data] # Handle raw dictionaries (legacy support) if isinstance(data, dict): - return {k: from_serializable_dict(v) for k, v in data.items()} + return {k: from_serializable(v) for k, v in data.items()} return data @@ -205,7 +207,7 @@ class BrowserConfig: Default: True. accept_downloads (bool): Whether to allow file downloads. If True, requires a downloads_path. Default: False. - downloads_path (str or None): Directory to store downloaded files. If None and accept_downloads is True, + downloads_path (Path or str or None): Directory to store downloaded files. If None and accept_downloads is True, a default path will be created. Default: None. storage_state (str or dict or None): An in-memory storage state (cookies, localStorage). Default: None. @@ -235,26 +237,26 @@ def __init__( headless: bool = True, browser_mode: str = "dedicated", use_managed_browser: bool = False, - cdp_url: str = None, + cdp_url: Optional[str] = None, use_persistent_context: bool = False, - user_data_dir: str = None, + user_data_dir: Optional[str] = None, chrome_channel: str = "chromium", channel: str = "chromium", - proxy: str = None, + proxy: Optional[str] = None, proxy_config: Union[ProxyConfig, dict, None] = None, docker_config: Union["DockerConfig", dict, None] = None, viewport_width: int = 1080, viewport_height: int = 600, - viewport: dict = None, + viewport: Optional[dict] = None, accept_downloads: bool = False, - downloads_path: str = None, + downloads_path: Optional[Union[Path, str]] = None, storage_state: Union[str, dict, None] = None, ignore_https_errors: bool = True, java_script_enabled: bool = True, sleep_on_close: bool = False, verbose: bool = True, - cookies: list = None, - headers: dict = None, + cookies: Optional[list] = None, + headers: Optional[dict] = None, user_agent: str = ( # "Mozilla/5.0 (Macintosh; Intel Mac OS X 10.15; rv:109.0) AppleWebKit/537.36 " # "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 " @@ -265,7 +267,7 @@ def __init__( user_agent_generator_config: dict = {}, text_mode: bool = False, light_mode: bool = False, - extra_args: list = None, + extra_args: Optional[list] = None, debugging_port: int = 9222, host: str = "localhost", ): @@ -373,8 +375,8 @@ def from_kwargs(kwargs: dict) -> "BrowserConfig": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " "AppleWebKit/537.36 (KHTML, like Gecko) Chrome/116.0.0.0 Safari/537.36", ), - user_agent_mode=kwargs.get("user_agent_mode"), - user_agent_generator_config=kwargs.get("user_agent_generator_config"), + user_agent_mode=kwargs.get("user_agent_mode", ""), + user_agent_generator_config=kwargs.get("user_agent_generator_config", ""), text_mode=kwargs.get("text_mode", False), light_mode=kwargs.get("light_mode", 
False), extra_args=kwargs.get("extra_args", []), @@ -438,15 +440,18 @@ def clone(self, **kwargs): config_dict.update(kwargs) return BrowserConfig.from_kwargs(config_dict) - # Create a funciton returns dict of the object - def dump(self) -> dict: + # Create a function returns dict of the object + def dump(self) -> Serialisable: # Serialize the object to a dictionary return to_serializable_dict(self) @staticmethod - def load(data: dict) -> "BrowserConfig": + def load(data: Serialisable) -> "BrowserConfig": + if data is None: + return BrowserConfig() + # Deserialize the object from a dictionary - config = from_serializable_dict(data) + config = from_serializable(data) if isinstance(config, BrowserConfig): return config return BrowserConfig.from_kwargs(config) @@ -512,17 +517,18 @@ def clone(self, **kwargs): config_dict.update(kwargs) return HTTPCrawlerConfig.from_kwargs(config_dict) - def dump(self) -> dict: + def dump(self) -> Serialisable: return to_serializable_dict(self) @staticmethod def load(data: dict) -> "HTTPCrawlerConfig": - config = from_serializable_dict(data) + config = from_serializable(data) if isinstance(config, HTTPCrawlerConfig): return config return HTTPCrawlerConfig.from_kwargs(config) -class CrawlerRunConfig(): + +class CrawlerRunConfig: _UNWANTED_PROPS = { 'disable_cache' : 'Instead, use cache_mode=CacheMode.DISABLED', 'bypass_cache' : 'Instead, use cache_mode=CacheMode.BYPASS', @@ -709,50 +715,50 @@ class CrawlerRunConfig(): into the main parameter set. Default: None. - url: str = None # This is not a compulsory parameter + url (str or None): This is not a compulsory parameter """ def __init__( self, # Content Processing Parameters word_count_threshold: int = MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, + extraction_strategy: Optional[ExtractionStrategy] = None, chunking_strategy: ChunkingStrategy = RegexChunking(), - markdown_generator: MarkdownGenerationStrategy = None, + markdown_generator: Optional[MarkdownGenerationStrategy] = None, only_text: bool = False, - css_selector: str = None, - target_elements: List[str] = None, - excluded_tags: list = None, - excluded_selector: str = None, + css_selector: Optional[str] = None, + target_elements: Optional[List[str]] = None, + excluded_tags: Optional[list] = None, + excluded_selector: Optional[str] = None, keep_data_attributes: bool = False, - keep_attrs: list = None, + keep_attrs: Optional[list] = None, remove_forms: bool = False, prettiify: bool = False, parser_type: str = "lxml", - scraping_strategy: ContentScrapingStrategy = None, + scraping_strategy: Optional[ContentScrapingStrategy] = None, proxy_config: Union[ProxyConfig, dict, None] = None, proxy_rotation_strategy: Optional[ProxyRotationStrategy] = None, # SSL Parameters fetch_ssl_certificate: bool = False, # Caching Parameters cache_mode: CacheMode = CacheMode.BYPASS, - session_id: str = None, + session_id: Optional[str] = None, bypass_cache: bool = False, disable_cache: bool = False, no_cache_read: bool = False, no_cache_write: bool = False, - shared_data: dict = None, + shared_data: Optional[dict] = None, # Page Navigation and Timing Parameters wait_until: str = "domcontentloaded", page_timeout: int = PAGE_TIMEOUT, - wait_for: str = None, + wait_for: Optional[str] = None, wait_for_images: bool = False, delay_before_return_html: float = 0.1, mean_delay: float = 0.1, max_range: float = 0.3, semaphore_count: int = 5, # Page Interaction Parameters - js_code: Union[str, List[str]] = None, + js_code: Optional[Union[str, List[str]]] = None, 
js_only: bool = False, ignore_body_visibility: bool = True, scan_full_page: bool = False, @@ -765,7 +771,7 @@ def __init__( adjust_viewport_to_content: bool = False, # Media Handling Parameters screenshot: bool = False, - screenshot_wait_for: float = None, + screenshot_wait_for: Optional[float] = None, screenshot_height_threshold: int = SCREENSHOT_HEIGHT_TRESHOLD, pdf: bool = False, image_description_min_word_threshold: int = IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, @@ -773,10 +779,10 @@ def __init__( table_score_threshold: int = 7, exclude_external_images: bool = False, # Link and Domain Handling Parameters - exclude_social_media_domains: list = None, + exclude_social_media_domains: Optional[list] = None, exclude_external_links: bool = False, exclude_social_media_links: bool = False, - exclude_domains: list = None, + exclude_domains: Optional[list] = None, exclude_internal_links: bool = False, # Debugging and Logging Parameters verbose: bool = True, @@ -784,15 +790,15 @@ def __init__( # Connection Parameters method: str = "GET", stream: bool = False, - url: str = None, + url: Optional[str] = None, check_robots_txt: bool = False, - user_agent: str = None, - user_agent_mode: str = None, + user_agent: Optional[str] = None, + user_agent_mode: Optional[str] = None, user_agent_generator_config: dict = {}, # Deep Crawl Parameters deep_crawl_strategy: Optional[DeepCrawlStrategy] = None, # Experimental Parameters - experimental: Dict[str, Any] = None, + experimental: Optional[Dict[str, Any]] = None, ): # TODO: Planning to set properties dynamically based on the __init__ signature self.url = url @@ -1021,15 +1027,18 @@ def from_kwargs(kwargs: dict) -> "CrawlerRunConfig": experimental=kwargs.get("experimental"), ) - # Create a funciton returns dict of the object - def dump(self) -> dict: + # Create a function returns dict of the object + def dump(self) -> Serialisable: # Serialize the object to a dictionary return to_serializable_dict(self) @staticmethod - def load(data: dict) -> "CrawlerRunConfig": + def load(data: Serialisable) -> "CrawlerRunConfig": + if data is None: + return CrawlerRunConfig() + # Deserialize the object from a dictionary - config = from_serializable_dict(data) + config = from_serializable(data) if isinstance(config, CrawlerRunConfig): return config return CrawlerRunConfig.from_kwargs(config) @@ -1139,15 +1148,15 @@ def __init__( provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None, base_url: Optional[str] = None, - temprature: Optional[float] = None, + temperature: Optional[float] = None, max_tokens: Optional[int] = None, top_p: Optional[float] = None, frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, stop: Optional[List[str]] = None, - n: Optional[int] = None, + n: Optional[int] = None, ): - """Configuaration class for LLM provider and API token.""" + """Configuration class for LLM provider and API token.""" self.provider = provider if api_token and not api_token.startswith("env:"): self.api_token = api_token @@ -1158,7 +1167,7 @@ def __init__( DEFAULT_PROVIDER_API_KEY ) self.base_url = base_url - self.temprature = temprature + self.temperature = temperature self.max_tokens = max_tokens self.top_p = top_p self.frequency_penalty = frequency_penalty @@ -1167,12 +1176,12 @@ def __init__( self.n = n @staticmethod - def from_kwargs(kwargs: dict) -> "LLMConfig": + def from_kwargs(kwargs: dict) -> LLMConfig: return LLMConfig( provider=kwargs.get("provider", DEFAULT_PROVIDER), api_token=kwargs.get("api_token"), 
base_url=kwargs.get("base_url"), - temprature=kwargs.get("temprature"), + temperature=kwargs.get("temperature"), max_tokens=kwargs.get("max_tokens"), top_p=kwargs.get("top_p"), frequency_penalty=kwargs.get("frequency_penalty"), @@ -1186,7 +1195,7 @@ def to_dict(self): "provider": self.provider, "api_token": self.api_token, "base_url": self.base_url, - "temprature": self.temprature, + "temperature": self.temperature, "max_tokens": self.max_tokens, "top_p": self.top_p, "frequency_penalty": self.frequency_penalty, diff --git a/crawl4ai/async_crawler_strategy.py b/crawl4ai/async_crawler_strategy.py index 37aa0962f..4cebc094f 100644 --- a/crawl4ai/async_crawler_strategy.py +++ b/crawl4ai/async_crawler_strategy.py @@ -4,10 +4,10 @@ import base64 import time from abc import ABC, abstractmethod -from typing import Callable, Dict, Any, List, Union -from typing import Optional, AsyncGenerator, Final +from typing import Callable, Dict, Any, List, Union, Self +from typing import Optional, AsyncGenerator, Final, Coroutine import os -from playwright.async_api import Page, Error +from playwright.async_api import Page, Error, Download from playwright.async_api import TimeoutError as PlaywrightTimeoutError from io import BytesIO from PIL import Image, ImageDraw, ImageFont @@ -17,7 +17,7 @@ from .models import AsyncCrawlResponse from .config import SCREENSHOT_HEIGHT_TRESHOLD from .async_configs import BrowserConfig, CrawlerRunConfig, HTTPCrawlerConfig -from .async_logger import AsyncLogger +from .async_logger import AsyncLogger, AsyncLoggerBase from .ssl_certificate import SSLCertificate from .user_agent_generator import ValidUAGenerator from .browser_manager import BrowserManager @@ -39,7 +39,28 @@ class AsyncCrawlerStrategy(ABC): @abstractmethod async def crawl(self, url: str, **kwargs) -> AsyncCrawlResponse: - pass # 4 + 3 + """Crawl a given URL and return the response.""" + + @abstractmethod + def set_hook(self, hook_type: str, hook: Callable): + """Set a hook function for a specific hook type.""" + + @abstractmethod + def update_user_agent(self, user_agent: str): + """Update the user agent for requests.""" + + @abstractmethod + def set_custom_headers(self, headers: Dict[str, str]): + """Set custom headers for requests.""" + + @abstractmethod + async def __aenter__(self) -> Self: + """Enter the context manager and start the crawler.""" + + @abstractmethod + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the context manager and clean up resources.""" + class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ @@ -47,7 +68,7 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): Attributes: browser_config (BrowserConfig): Configuration object containing browser settings. - logger (AsyncLogger): Logger instance for recording events and errors. + logger (AsyncLoggerBase): Logger instance for recording events and errors. _downloaded_files (List[str]): List of downloaded file paths. hooks (Dict[str, Callable]): Dictionary of hooks for custom behavior. browser_manager (BrowserManager): Manager for browser creation and management. @@ -71,7 +92,10 @@ class AsyncPlaywrightCrawlerStrategy(AsyncCrawlerStrategy): """ def __init__( - self, browser_config: BrowserConfig = None, logger: AsyncLogger = None, **kwargs + self, + browser_config: Optional[BrowserConfig] = None, + logger: Optional[AsyncLoggerBase] = None, + **kwargs, ): """ Initialize the AsyncPlaywrightCrawlerStrategy with a browser configuration. 
@@ -84,13 +108,15 @@ def __init__( """ # Initialize browser config, either from provided object or kwargs self.browser_config = browser_config or BrowserConfig.from_kwargs(kwargs) - self.logger = logger + self.logger = logger or AsyncLogger( + verbose=self.browser_config.verbose, + ) # Initialize session management - self._downloaded_files = [] + self._download_tasks: list[Coroutine] = [] # Initialize hooks system - self.hooks = { + self.hooks: dict[str, Optional[Callable]] = { "on_browser_created": None, "on_page_context_created": None, "on_user_agent_updated": None, @@ -107,7 +133,7 @@ def __init__( browser_config=self.browser_config, logger=self.logger ) - async def __aenter__(self): + async def __aenter__(self) -> Self: await self.start() return self @@ -204,7 +230,7 @@ def update_user_agent(self, user_agent: str): Returns: None """ - self.user_agent = user_agent + self.browser_config.user_agent = user_agent def set_custom_headers(self, headers: Dict[str, str]): """ @@ -216,7 +242,7 @@ def set_custom_headers(self, headers: Dict[str, str]): Returns: None """ - self.headers = headers + self.browser_config.headers = headers async def smart_wait(self, page: Page, wait_for: str, timeout: float = 30000): """ @@ -446,6 +472,8 @@ async def crawl( with open(local_file_path, "r", encoding="utf-8") as f: html = f.read() if config.screenshot: + # TODO: Fix as this method has been removed, see: + # https://github.com/unclecode/crawl4ai/commit/c38ac29edbcebcb2f3672145424e7af3193caa6e#diff-ede9fd3357068d6fca3803250d507a6816ed8476e6fa71990948bcedb1c460fbR1073 screenshot_data = await self._generate_screenshot_from_html(html) return AsyncCrawlResponse( html=html, @@ -460,6 +488,8 @@ async def crawl( raw_html = url[4:] if url[:4] == "raw:" else url[7:] html = raw_html if config.screenshot: + # TODO: Fix as this method has been removed, see: + # https://github.com/unclecode/crawl4ai/commit/c38ac29edbcebcb2f3672145424e7af3193caa6e#diff-ede9fd3357068d6fca3803250d507a6816ed8476e6fa71990948bcedb1c460fbR1073 screenshot_data = await self._generate_screenshot_from_html(html) return AsyncCrawlResponse( html=html, @@ -490,7 +520,7 @@ async def _crawl_web( response_headers = {} execution_result = None status_code = None - redirected_url = url + redirected_url = url # Reset downloaded files list for new crawl self._downloaded_files = [] @@ -524,14 +554,19 @@ async def _crawl_web( # Set up console logging if requested if config.log_console: - def log_consol( - msg, console_log_type="debug" - ): # Corrected the parameter syntax + def log_consol(msg, console_log_type="debug"): # Corrected the parameter syntax if console_log_type == "error": + text: str = "unknown" + if isinstance(msg, Error): + text = msg.message + elif isinstance(msg, str): + text = msg + elif hasattr(msg, "text"): + text = msg.text self.logger.error( message=f"Console error: {msg}", # Use f-string for variable interpolation tag="CONSOLE", - params={"msg": msg.text}, + params={"msg": text}, ) elif console_log_type == "debug": self.logger.debug( @@ -607,8 +642,8 @@ def log_consol( const element = document.body; if (!element) return false; const style = window.getComputedStyle(element); - const isVisible = style.display !== 'none' && - style.visibility !== 'hidden' && + const isVisible = style.display !== 'none' && + style.visibility !== 'hidden' && style.opacity !== '0'; return isVisible; }""", @@ -815,21 +850,21 @@ def log_consol( # Handle comma-separated selectors by splitting them selectors = [s.strip() for s in config.css_selector.split(',')] 
html_parts = [] - + for selector in selectors: try: content = await page.evaluate(f"document.querySelector('{selector}')?.outerHTML || ''") html_parts.append(content) except Error as e: print(f"Warning: Could not get content for selector '{selector}': {str(e)}") - + # Wrap in a div to create a valid HTML structure - html = f"
\n" + "\n".join(html_parts) + "\n
" + html = "
\n" + "\n".join(html_parts) + "\n
" except Error as e: raise RuntimeError(f"Failed to extract HTML content: {str(e)}") else: html = await page.content() - + # # Get final HTML content # html = await page.content() await self.execute_hook( @@ -868,6 +903,13 @@ async def get_delayed_content(delay: float = 5.0) -> str: await asyncio.sleep(delay) return await page.content() + async def gather_downloads() -> Optional[List[str]]: + """Gather all download tasks and return the list of downloaded files.""" + if not self._download_tasks: + return None + + return [download for download in await asyncio.gather(*self._download_tasks) if download is not None] + # Return complete response return AsyncCrawlResponse( html=html, @@ -878,9 +920,7 @@ async def get_delayed_content(delay: float = 5.0) -> str: pdf_data=pdf_data, get_delayed_content=get_delayed_content, ssl_certificate=ssl_cert, - downloaded_files=( - self._downloaded_files if self._downloaded_files else None - ), + downloaded_files=await gather_downloads(), redirected_url=redirected_url, ) @@ -957,7 +997,7 @@ async def _handle_full_page_scan(self, page: Page, scroll_delay: float = 0.1): # await page.evaluate("window.scrollTo(0, document.body.scrollHeight)") await self.safe_scroll(page, 0, total_height) - async def _handle_download(self, download): + async def _handle_download(self, download: Download): """ Handle file downloads. @@ -975,36 +1015,45 @@ async def _handle_download(self, download): Returns: None """ - try: - suggested_filename = download.suggested_filename - download_path = os.path.join(self.browser_config.downloads_path, suggested_filename) + suggested_filename = download.suggested_filename + download_path: str = ( + suggested_filename + if not self.browser_config.downloads_path + else os.path.join(self.browser_config.downloads_path, suggested_filename) + ) - self.logger.info( - message="Downloading {filename} to {path}", - tag="FETCH", - params={"filename": suggested_filename, "path": download_path}, - ) + self.logger.info( + message="Downloading {filename} to {path}", + tag="FETCH", + params={"filename": suggested_filename, "path": download_path}, + ) - start_time = time.perf_counter() - await download.save_as(download_path) - end_time = time.perf_counter() - self._downloaded_files.append(download_path) - - self.logger.success( - message="Downloaded {filename} successfully", - tag="COMPLETE", - params={ - "filename": suggested_filename, - "path": download_path, - "duration": f"{end_time - start_time:.2f}s", - }, - ) - except Exception as e: - self.logger.error( - message="Failed to handle download: {error}", - tag="ERROR", - params={"error": str(e)}, - ) + start_time = time.perf_counter() + + async def download_task(download_path: str) -> Optional[str]: + try: + await download.save_as(download_path) + end_time = time.perf_counter() + + self.logger.success( + message="Downloaded {filename} successfully", + tag="COMPLETE", + params={ + "filename": suggested_filename, + "path": download_path, + "duration": f"{end_time - start_time:.2f}s", + }, + ) + return download_path + except Exception as e: + self.logger.error( + message="Failed to handle download: {error}", + tag="ERROR", + params={"error": str(e)}, + ) + return None + + self._download_tasks.append(download_task(download_path)) async def remove_overlay_elements(self, page: Page) -> None: """ @@ -1221,7 +1270,7 @@ async def take_screenshot_naive(self, page: Page) -> str: finally: await page.close() - async def export_storage_state(self, path: str = None) -> dict: + async def export_storage_state(self, path: 
Optional[str] = None) -> Optional[dict]: """ Exports the current storage state (cookies, localStorage, sessionStorage) to a JSON file at the specified path. @@ -1277,6 +1326,7 @@ async def robust_execute_user_script( results = [] for script in scripts: + script = script.strip(";") try: # Attempt the evaluate # If the user code triggers navigation, we catch the "context destroyed" error @@ -1287,8 +1337,8 @@ async def robust_execute_user_script( f""" (async () => {{ try {{ - const script_result = {script}; - return {{ success: true, result: script_result }}; + {script}; + return {{ success: true }}; }} catch (err) {{ return {{ success: false, error: err.toString(), stack: err.stack }}; }} @@ -1335,7 +1385,6 @@ async def robust_execute_user_script( result = {"success": False, "error": str(e)} # If we made it this far with no repeated error, do post-load waits - t1 = time.time() try: await page.wait_for_load_state("domcontentloaded", timeout=5000) except Error as e: @@ -1345,17 +1394,6 @@ async def robust_execute_user_script( params={"error": str(e)}, ) - # t1 = time.time() - # try: - # await page.wait_for_load_state('networkidle', timeout=5000) - # print("Network idle after script execution in", time.time() - t1) - # except Error as e: - # self.logger.warning( - # message="Network idle timeout: {error}", - # tag="JS_EXEC", - # params={"error": str(e)} - # ) - results.append(result if result else {"success": True}) except Exception as e: @@ -1412,7 +1450,7 @@ async def execute_user_script( const result = (function() {{ {script} }})(); - + // If result is a promise, wait for it if (result instanceof Promise) {{ result.then(() => {{ @@ -1442,11 +1480,7 @@ async def execute_user_script( ) # Wait for network idle after script execution - t1 = time.time() await page.wait_for_load_state("domcontentloaded", timeout=5000) - - - t1 = time.time() await page.wait_for_load_state("networkidle", timeout=5000) results.append(result if result else {"success": True}) @@ -1470,14 +1504,6 @@ async def execute_user_script( ) return {"success": False, "error": str(e)} - except Exception as e: - self.logger.error( - message="Script execution failed: {error}", - tag="JS_EXEC", - params={"error": str(e)}, - ) - return {"success": False, "error": str(e)} - async def check_visibility(self, page): """ Checks if an element is visible on the page. @@ -1494,8 +1520,8 @@ async def check_visibility(self, page): const element = document.body; if (!element) return false; const style = window.getComputedStyle(element); - const isVisible = style.display !== 'none' && - style.visibility !== 'hidden' && + const isVisible = style.display !== 'none' && + style.visibility !== 'hidden' && style.opacity !== '0'; return isVisible; } @@ -1535,11 +1561,11 @@ async def csp_scroll_to(self, page: Page, x: int, y: int) -> Dict[str, Any]: const startX = window.scrollX; const startY = window.scrollY; window.scrollTo({x}, {y}); - + // Get final position after scroll const endX = window.scrollX; const endY = window.scrollY; - + return {{ success: true, startPosition: {{ x: startX, y: startY }}, @@ -1650,11 +1676,11 @@ class AsyncHTTPCrawlerStrategy(AsyncCrawlerStrategy): """ Fast, lightweight HTTP-only crawler strategy optimized for memory efficiency. 
""" - + __slots__ = ('logger', 'max_connections', 'dns_cache_ttl', 'chunk_size', '_session', 'hooks', 'browser_config') DEFAULT_TIMEOUT: Final[int] = 30 - DEFAULT_CHUNK_SIZE: Final[int] = 64 * 1024 + DEFAULT_CHUNK_SIZE: Final[int] = 64 * 1024 DEFAULT_MAX_CONNECTIONS: Final[int] = min(32, (os.cpu_count() or 1) * 4) DEFAULT_DNS_CACHE_TTL: Final[int] = 300 VALID_SCHEMES: Final = frozenset({'http', 'https', 'file', 'raw'}) @@ -1667,9 +1693,9 @@ class AsyncHTTPCrawlerStrategy(AsyncCrawlerStrategy): 'Upgrade-Insecure-Requests': '1', 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36' }) - + def __init__( - self, + self, browser_config: Optional[HTTPCrawlerConfig] = None, logger: Optional[AsyncLogger] = None, max_connections: int = DEFAULT_MAX_CONNECTIONS, @@ -1683,9 +1709,10 @@ def __init__( self.dns_cache_ttl = dns_cache_ttl self.chunk_size = chunk_size self._session: Optional[aiohttp.ClientSession] = None - + self._base_headers = self._BASE_HEADERS.copy() + self.hooks = { - k: partial(self._execute_hook, k) + k: partial(self._execute_hook, k) for k in ('before_request', 'after_request', 'on_error') } @@ -1693,12 +1720,13 @@ def __init__( self.set_hook('before_request', lambda *args, **kwargs: None) self.set_hook('after_request', lambda *args, **kwargs: None) self.set_hook('on_error', lambda *args, **kwargs: None) - - async def __aenter__(self) -> AsyncHTTPCrawlerStrategy: + + + async def __aenter__(self) -> Self: await self.start() return self - + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.close() @@ -1711,17 +1739,41 @@ async def _session_context(self): finally: await self.close() - def set_hook(self, hook_type: str, hook_func: Callable) -> None: + def set_hook(self, hook_type: str, hook: Callable) -> None: if hook_type in self.hooks: - self.hooks[hook_type] = partial(self._execute_hook, hook_type, hook_func) + self.hooks[hook_type] = partial(self._execute_hook, hook_type, hook) else: raise ValueError(f"Invalid hook type: {hook_type}") + def update_user_agent(self, user_agent: str): + """ + Update the user agent for requests + + Args: + user_agent (str): The new user agent string. + + Returns: + None + """ + self._base_headers["User-Agent"] = user_agent + + def set_custom_headers(self, headers: Dict[str, str]): + """ + Set custom headers for requests. + + Args: + headers (Dict[str, str]): A dictionary of headers to set. 
+ + Returns: + None + """ + self._base_headers.update(headers) + async def _execute_hook( - self, - hook_type: str, + self, + hook_type: str, hook_func: Callable, - *args: Any, + *args: Any, **kwargs: Any ) -> Any: if asyncio.iscoroutinefunction(hook_func): @@ -1737,7 +1789,7 @@ async def start(self) -> None: force_close=False ) self._session = aiohttp.ClientSession( - headers=dict(self._BASE_HEADERS), + headers=dict(self._base_headers), connector=connector, timeout=ClientTimeout(total=self.DEFAULT_TIMEOUT) ) @@ -1763,11 +1815,11 @@ async def _stream_file(self, path: str) -> AsyncGenerator[memoryview, None]: async def _handle_file(self, path: str) -> AsyncCrawlResponse: if not os.path.exists(path): raise FileNotFoundError(f"Local file not found: {path}") - + chunks = [] async for chunk in self._stream_file(path): chunks.append(chunk.tobytes().decode('utf-8', errors='replace')) - + return AsyncCrawlResponse( html=''.join(chunks), response_headers={}, @@ -1783,8 +1835,8 @@ async def _handle_raw(self, content: str) -> AsyncCrawlResponse: async def _handle_http( - self, - url: str, + self, + url: str, config: CrawlerRunConfig ) -> AsyncCrawlResponse: async with self._session_context() as session: @@ -1793,8 +1845,8 @@ async def _handle_http( connect=10, sock_read=30 ) - - headers = dict(self._BASE_HEADERS) + + headers = dict(self._base_headers) if self.browser_config.headers: headers.update(self.browser_config.headers) @@ -1816,69 +1868,69 @@ async def _handle_http( try: async with session.request(self.browser_config.method, url, **request_kwargs) as response: content = memoryview(await response.read()) - + if not (200 <= response.status < 300): raise HTTPStatusError( response.status, f"Unexpected status code for {url}" ) - + encoding = response.charset if not encoding: - encoding = cchardet.detect(content.tobytes())['encoding'] or 'utf-8' - + encoding = cchardet.detect(content.tobytes())['encoding'] or 'utf-8' + result = AsyncCrawlResponse( html=content.tobytes().decode(encoding, errors='replace'), response_headers=dict(response.headers), status_code=response.status, redirected_url=str(response.url) ) - + await self.hooks['after_request'](result) return result except aiohttp.ServerTimeoutError as e: await self.hooks['on_error'](e) raise ConnectionTimeoutError(f"Request timed out: {str(e)}") - + except aiohttp.ClientConnectorError as e: await self.hooks['on_error'](e) raise ConnectionError(f"Connection failed: {str(e)}") - + except aiohttp.ClientError as e: await self.hooks['on_error'](e) raise HTTPCrawlerError(f"HTTP client error: {str(e)}") - + except asyncio.exceptions.TimeoutError as e: await self.hooks['on_error'](e) raise ConnectionTimeoutError(f"Request timed out: {str(e)}") - + except Exception as e: await self.hooks['on_error'](e) raise HTTPCrawlerError(f"HTTP request failed: {str(e)}") async def crawl( - self, - url: str, - config: Optional[CrawlerRunConfig] = None, + self, + url: str, + config: Optional[CrawlerRunConfig] = None, **kwargs ) -> AsyncCrawlResponse: config = config or CrawlerRunConfig.from_kwargs(kwargs) - + parsed = urlparse(url) scheme = parsed.scheme.rstrip('/') - + if scheme not in self.VALID_SCHEMES: raise ValueError(f"Unsupported URL scheme: {scheme}") - + try: if scheme == 'file': return await self._handle_file(parsed.path) elif scheme == 'raw': - return await self._handle_raw(parsed.path) + return await self._handle_raw(url.removeprefix('raw://')) else: # http or https return await self._handle_http(url, config) - + except Exception as e: if self.logger: 
self.logger.error( diff --git a/crawl4ai/async_database.py b/crawl4ai/async_database.py index 870350e9c..9d613fac9 100644 --- a/crawl4ai/async_database.py +++ b/crawl4ai/async_database.py @@ -2,7 +2,7 @@ from pathlib import Path import aiosqlite import asyncio -from typing import Optional, Dict +from typing import Optional, Dict, TypeVar, Callable, Awaitable from contextlib import asynccontextmanager import json from .models import CrawlResult, MarkdownGenerationResult, StringCompatibleMarkdown @@ -19,6 +19,10 @@ os.makedirs(DB_PATH, exist_ok=True) DB_PATH = os.path.join(base_directory, "crawl4ai.db") +R = TypeVar("R") + +class MalformedTableError(Exception): + """Raised when a table is missing required columns.""" class AsyncDatabaseManager: def __init__(self, pool_size: int = 10, max_retries: int = 3): @@ -134,30 +138,38 @@ async def get_connection(self): await conn.execute("PRAGMA busy_timeout = 5000") # Verify database structure - async with conn.execute( - "PRAGMA table_info(crawled_data)" - ) as cursor: - columns = await cursor.fetchall() - column_names = [col[1] for col in columns] - expected_columns = { - "url", - "html", - "cleaned_html", - "markdown", - "extracted_content", - "success", - "media", - "links", - "metadata", - "screenshot", - "response_headers", - "downloaded_files", - } - missing_columns = expected_columns - set(column_names) - if missing_columns: - raise ValueError( - f"Database missing columns: {missing_columns}" - ) + retry: bool = True + while retry: + async with conn.execute( + "PRAGMA table_info(crawled_data)" + ) as cursor: + columns = await cursor.fetchall() + if not columns: + # Table doesn't exist, reinitialize. + await self.initialize() + continue + + retry = False + column_names = [col[1] for col in columns] + expected_columns = { + "url", + "html", + "cleaned_html", + "markdown", + "extracted_content", + "success", + "media", + "links", + "metadata", + "screenshot", + "response_headers", + "downloaded_files", + } + missing_columns = expected_columns - set(column_names) + if missing_columns: + raise MalformedTableError( + f"Database missing columns: {missing_columns}" + ) self.connection_pool[task_id] = conn except Exception as e: @@ -199,14 +211,23 @@ async def get_connection(self): del self.connection_pool[task_id] self.connection_semaphore.release() - async def execute_with_retry(self, operation, *args): + async def execute_with_retry( + self, operation: Callable[[aiosqlite.Connection], Awaitable[R]] + ) -> Optional[R]: """Execute database operations with retry logic""" for attempt in range(self.max_retries): try: async with self.get_connection() as db: - result = await operation(db, *args) - await db.commit() - return result + return await operation(db) + except MalformedTableError as e: + # Table is malformed, no point in retrying. 
+ self.logger.error( + message="Operation failed after {retries} attempts: {error}", + tag="ERROR", + force_verbose=True, + params={"retries": self.max_retries, "error": str(e)}, + ) + raise except Exception as e: if attempt == self.max_retries - 1: self.logger.error( @@ -218,6 +239,8 @@ async def execute_with_retry(self, operation, *args): raise await asyncio.sleep(1 * (attempt + 1)) # Exponential backoff + return None + async def ainit_db(self): """Initialize database schema""" async with aiosqlite.connect(self.db_path, timeout=30.0) as db: @@ -282,7 +305,7 @@ async def aalter_db_add_column(self, new_column: str, db): async def aget_cached_url(self, url: str) -> Optional[CrawlResult]: """Retrieve cached URL data as CrawlResult""" - async def _get(db): + async def _get(db: aiosqlite.Connection) -> Optional[CrawlResult]: async with db.execute( "SELECT * FROM crawled_data WHERE url = ?", (url,) ) as cursor: @@ -341,10 +364,6 @@ async def _get(db): else: row_dict[field] = {} - if isinstance(row_dict["markdown"], Dict): - if row_dict["markdown"].get("raw_markdown"): - row_dict["markdown"] = row_dict["markdown"]["raw_markdown"] - # Parse downloaded_files try: row_dict["downloaded_files"] = ( @@ -386,7 +405,7 @@ async def acache_url(self, result: CrawlResult): try: if isinstance(result.markdown, StringCompatibleMarkdown): content_map["markdown"] = ( - result.markdown, + result.markdown.markdown_result.model_dump_json(), "markdown", ) elif isinstance(result.markdown, MarkdownGenerationResult): @@ -419,7 +438,7 @@ async def acache_url(self, result: CrawlResult): for field, (content, content_type) in content_map.items(): content_hashes[field] = await self._store_content(content, content_type) - async def _cache(db): + async def _cache(db: aiosqlite.Connection) -> None: await db.execute( """ INSERT INTO crawled_data ( @@ -456,6 +475,7 @@ async def _cache(db): json.dumps(result.downloaded_files or []), ), ) + await db.commit() try: await self.execute_with_retry(_cache) @@ -470,13 +490,14 @@ async def _cache(db): async def aget_total_count(self) -> int: """Get total number of cached URLs""" - async def _count(db): + async def _count(db: aiosqlite.Connection) -> int: async with db.execute("SELECT COUNT(*) FROM crawled_data") as cursor: result = await cursor.fetchone() return result[0] if result else 0 try: - return await self.execute_with_retry(_count) + result: Optional[int] = await self.execute_with_retry(_count) + return result or 0 except Exception as e: self.logger.error( message="Error getting total count: {error}", @@ -489,8 +510,9 @@ async def _count(db): async def aclear_db(self): """Clear all data from the database""" - async def _clear(db): + async def _clear(db: aiosqlite.Connection) -> None: await db.execute("DELETE FROM crawled_data") + await db.commit() try: await self.execute_with_retry(_clear) @@ -505,8 +527,9 @@ async def _clear(db): async def aflush_db(self): """Drop the entire table""" - async def _flush(db): + async def _flush(db: aiosqlite.Connection) -> None: await db.execute("DROP TABLE IF EXISTS crawled_data") + await db.commit() try: await self.execute_with_retry(_flush) @@ -544,7 +567,7 @@ async def _load_content( try: async with aiofiles.open(file_path, "r", encoding="utf-8") as f: return await f.read() - except: + except Exception: self.logger.error( message="Failed to load content: {file_path}", tag="ERROR", diff --git a/crawl4ai/async_dispatcher.py b/crawl4ai/async_dispatcher.py index b97d59a7b..67bcfae22 100644 --- a/crawl4ai/async_dispatcher.py +++ 
b/crawl4ai/async_dispatcher.py @@ -2,6 +2,7 @@ from .async_configs import CrawlerRunConfig from .models import ( CrawlResult, + CrawlResultContainer, CrawlerTaskResult, CrawlStatus, DomainState, @@ -29,7 +30,7 @@ def __init__( base_delay: Tuple[float, float] = (1.0, 3.0), max_delay: float = 60.0, max_retries: int = 3, - rate_limit_codes: List[int] = None, + rate_limit_codes: Optional[List[int]] = None, ): self.base_delay = base_delay self.max_delay = max_delay @@ -90,7 +91,6 @@ def __init__( rate_limiter: Optional[RateLimiter] = None, monitor: Optional[CrawlerMonitor] = None, ): - self.crawler = None self._domain_last_hit: Dict[str, float] = {} self.concurrent_sessions = 0 self.rate_limiter = rate_limiter @@ -99,10 +99,10 @@ def __init__( @abstractmethod async def crawl_url( self, + crawler: AsyncWebCrawler, url: str, config: CrawlerRunConfig, task_id: str, - monitor: Optional[CrawlerMonitor] = None, ) -> CrawlerTaskResult: pass @@ -110,12 +110,20 @@ async def crawl_url( async def run_urls( self, urls: List[str], - crawler: AsyncWebCrawler, # noqa: F821 + crawler: AsyncWebCrawler, config: CrawlerRunConfig, - monitor: Optional[CrawlerMonitor] = None, ) -> List[CrawlerTaskResult]: pass + @abstractmethod + async def run_urls_stream( + self, + urls: List[str], + crawler: AsyncWebCrawler, + config: CrawlerRunConfig, + ) -> AsyncGenerator[CrawlerTaskResult, None]: + yield NotImplemented + class MemoryAdaptiveDispatcher(BaseDispatcher): def __init__( @@ -179,6 +187,7 @@ def _get_priority_score(self, wait_time: float, retry_count: int) -> float: async def crawl_url( self, + crawler: AsyncWebCrawler, url: str, config: CrawlerRunConfig, task_id: str, @@ -225,9 +234,11 @@ async def crawl_url( return CrawlerTaskResult( task_id=task_id, url=url, - result=CrawlResult( - url=url, html="", metadata={"status": "requeued"}, - success=False, error_message="Requeued due to critical memory pressure" + result=CrawlResultContainer( + CrawlResult( + url=url, html="", metadata={"status": "requeued"}, + success=False, error_message="Requeued due to critical memory pressure" + ) ), memory_usage=0, peak_memory=0, @@ -238,8 +249,8 @@ async def crawl_url( ) # Execute the crawl - result = await self.crawler.arun(url, config=config, session_id=task_id) - + result: CrawlResultContainer = await crawler.arun(url, config=config, session_id=task_id) + # Measure memory usage end_memory = process.memory_info().rss / (1024 * 1024) memory_usage = peak_memory = end_memory - start_memory @@ -263,8 +274,10 @@ async def crawl_url( error_message = str(e) if self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.FAILED) - result = CrawlResult( - url=url, html="", metadata={}, success=False, error_message=str(e) + result = CrawlResultContainer( + CrawlResult( + url=url, html="", metadata={}, success=False, error_message=str(e) + ) ) finally: @@ -288,7 +301,7 @@ async def crawl_url( peak_memory=peak_memory, start_time=start_time, end_time=end_time, - error_message=error_message, + error_message=error_message or "", retry_count=retry_count ) @@ -298,16 +311,14 @@ async def run_urls( crawler: AsyncWebCrawler, config: CrawlerRunConfig, ) -> List[CrawlerTaskResult]: - self.crawler = crawler - # Start the memory monitor task memory_monitor = asyncio.create_task(self._memory_monitor_task()) if self.monitor: self.monitor.start() - - results = [] - + + results: List[CrawlerTaskResult] = [] + try: # Initialize task queue for url in urls: @@ -331,7 +342,7 @@ async def run_urls( # Create and start the task task = asyncio.create_task( - 
self.crawl_url(url, config, task_id, retry_count) + self.crawl_url(crawler, url, config, task_id, retry_count) ) active_tasks.append(task) @@ -367,9 +378,6 @@ async def run_urls( # Update priorities for waiting tasks if needed await self._update_queue_priorities() - - return results - except Exception as e: if self.monitor: self.monitor.update_memory_status(f"QUEUE_ERROR: {str(e)}") @@ -379,7 +387,9 @@ async def run_urls( memory_monitor.cancel() if self.monitor: self.monitor.stop() - + + return results + async def _update_queue_priorities(self): """Periodically update priorities of items in the queue to prevent starvation""" # Skip if queue is empty @@ -416,8 +426,9 @@ async def _update_queue_priorities(self): break except Exception as e: # If anything goes wrong, make sure we refill the queue with what we've got - self.monitor.update_memory_status(f"QUEUE_ERROR: {str(e)}") - + if self.monitor: + self.monitor.update_memory_status(f"QUEUE_ERROR: {str(e)}") + # Calculate queue statistics if temp_items and self.monitor: total_queued = len(temp_items) @@ -445,8 +456,6 @@ async def run_urls_stream( crawler: AsyncWebCrawler, config: CrawlerRunConfig, ) -> AsyncGenerator[CrawlerTaskResult, None]: - self.crawler = crawler - # Start the memory monitor task memory_monitor = asyncio.create_task(self._memory_monitor_task()) @@ -477,7 +486,7 @@ async def run_urls_stream( # Create and start the task task = asyncio.create_task( - self.crawl_url(url, config, task_id, retry_count) + self.crawl_url(crawler, url, config, task_id, retry_count) ) active_tasks.append(task) @@ -538,16 +547,20 @@ def __init__( async def crawl_url( self, + crawler: AsyncWebCrawler, url: str, config: CrawlerRunConfig, task_id: str, - semaphore: asyncio.Semaphore = None, + semaphore: Optional[asyncio.Semaphore] = None, ) -> CrawlerTaskResult: start_time = time.time() error_message = "" memory_usage = peak_memory = 0.0 try: + if semaphore is None: + raise ValueError(f"Semaphore must be provided to {self.__class__.__name__}") + if self.monitor: self.monitor.update_task( task_id, status=CrawlStatus.IN_PROGRESS, start_time=start_time @@ -559,7 +572,7 @@ async def crawl_url( async with semaphore: process = psutil.Process() start_memory = process.memory_info().rss / (1024 * 1024) - result = await self.crawler.arun(url, config=config, session_id=task_id) + result: CrawlResultContainer = await crawler.arun(url, config=config, session_id=task_id) end_memory = process.memory_info().rss / (1024 * 1024) memory_usage = peak_memory = end_memory - start_memory @@ -591,8 +604,10 @@ async def crawl_url( error_message = str(e) if self.monitor: self.monitor.update_task(task_id, status=CrawlStatus.FAILED) - result = CrawlResult( - url=url, html="", metadata={}, success=False, error_message=str(e) + result = CrawlResultContainer( + CrawlResult( + url=url, html="", metadata={}, success=False, error_message=str(e) + ) ) finally: @@ -614,16 +629,15 @@ async def crawl_url( peak_memory=peak_memory, start_time=start_time, end_time=end_time, - error_message=error_message, + error_message=error_message or "", ) async def run_urls( self, - crawler: AsyncWebCrawler, # noqa: F821 urls: List[str], + crawler: AsyncWebCrawler, config: CrawlerRunConfig, ) -> List[CrawlerTaskResult]: - self.crawler = crawler if self.monitor: self.monitor.start() @@ -636,11 +650,40 @@ async def run_urls( if self.monitor: self.monitor.add_task(task_id, url) task = asyncio.create_task( - self.crawl_url(url, config, task_id, semaphore) + self.crawl_url(crawler, url, config, task_id, 
semaphore) + ) + tasks.append(task) + + return await asyncio.gather(*tasks) + finally: + if self.monitor: + self.monitor.stop() + + async def run_urls_stream( + self, + urls: List[str], + crawler: AsyncWebCrawler, + config: CrawlerRunConfig, + ) -> AsyncGenerator[CrawlerTaskResult, None]: + if self.monitor: + self.monitor.start() + + try: + semaphore = asyncio.Semaphore(self.semaphore_count) + tasks = [] + + for url in urls: + task_id = str(uuid.uuid4()) + if self.monitor: + self.monitor.add_task(task_id, url) + task = asyncio.create_task( + self.crawl_url(crawler, url, config, task_id, semaphore) ) tasks.append(task) - return await asyncio.gather(*tasks, return_exceptions=True) + for task in asyncio.as_completed(tasks): + result = await task + yield result finally: if self.monitor: - self.monitor.stop() \ No newline at end of file + self.monitor.stop() diff --git a/crawl4ai/async_webcrawler.py b/crawl4ai/async_webcrawler.py index bbee502bb..34180b441 100644 --- a/crawl4ai/async_webcrawler.py +++ b/crawl4ai/async_webcrawler.py @@ -4,19 +4,14 @@ import time from colorama import Fore from pathlib import Path -from typing import Optional, List, Generic, TypeVar +from typing import Optional, List, Self import json import asyncio -# from contextlib import nullcontext, asynccontextmanager from contextlib import asynccontextmanager -from .models import CrawlResult, MarkdownGenerationResult, DispatchResult, ScrapingResult +from .models import CrawlResult, CrawlerTaskResult, MarkdownGenerationResult, DispatchResult, ScrapingResult, CrawlResultContainer from .async_database import async_db_manager -from .chunking_strategy import * # noqa: F403 from .chunking_strategy import IdentityChunking -from .content_filter_strategy import * # noqa: F403 -from .extraction_strategy import * # noqa: F403 -from .extraction_strategy import NoExtractionStrategy from .async_crawler_strategy import ( AsyncCrawlerStrategy, AsyncPlaywrightCrawlerStrategy, @@ -30,7 +25,6 @@ from .deep_crawling import DeepCrawlDecorator from .async_logger import AsyncLogger, AsyncLoggerBase from .async_configs import BrowserConfig, CrawlerRunConfig -from .async_dispatcher import * # noqa: F403 from .async_dispatcher import BaseDispatcher, MemoryAdaptiveDispatcher, RateLimiter from .utils import ( @@ -42,44 +36,7 @@ RobotsParser, ) -from typing import Union, AsyncGenerator - -CrawlResultT = TypeVar('CrawlResultT', bound=CrawlResult) -# RunManyReturn = Union[CrawlResultT, List[CrawlResultT], AsyncGenerator[CrawlResultT, None]] - -class CrawlResultContainer(Generic[CrawlResultT]): - def __init__(self, results: Union[CrawlResultT, List[CrawlResultT]]): - # Normalize to a list - if isinstance(results, list): - self._results = results - else: - self._results = [results] - - def __iter__(self): - return iter(self._results) - - def __getitem__(self, index): - return self._results[index] - - def __len__(self): - return len(self._results) - - def __getattr__(self, attr): - # Delegate attribute access to the first element. - if self._results: - return getattr(self._results[0], attr) - raise AttributeError(f"{self.__class__.__name__} object has no attribute '{attr}'") - - def __repr__(self): - return f"{self.__class__.__name__}({self._results!r})" - -# Redefine the union type. Now synchronous calls always return a container, -# while stream mode is handled with an AsyncGenerator. 
-RunManyReturn = Union[ - CrawlResultContainer[CrawlResultT], - AsyncGenerator[CrawlResultT, None] -] - +from typing import AsyncGenerator class AsyncWebCrawler: @@ -141,11 +98,11 @@ class AsyncWebCrawler: def __init__( self, - crawler_strategy: AsyncCrawlerStrategy = None, - config: BrowserConfig = None, + crawler_strategy: Optional[AsyncCrawlerStrategy] = None, + config: Optional[BrowserConfig] = None, base_directory: str = str(os.getenv("CRAWL4_AI_BASE_DIRECTORY", Path.home())), thread_safe: bool = False, - logger: AsyncLoggerBase = None, + logger: Optional[AsyncLoggerBase] = None, **kwargs, ): """ @@ -159,7 +116,7 @@ def __init__( **kwargs: Additional arguments for backwards compatibility """ # Handle browser configuration - browser_config = config or BrowserConfig() + browser_config = config or BrowserConfig.from_kwargs(kwargs) self.browser_config = browser_config @@ -245,7 +202,7 @@ async def close(self): """ await self.crawler_strategy.__aexit__(None, None, None) - async def __aenter__(self): + async def __aenter__(self) -> Self: return await self.start() async def __aexit__(self, exc_type, exc_val, exc_tb): @@ -265,20 +222,22 @@ async def awarmup(self): @asynccontextmanager async def nullcontext(self): - """异步空上下文管理器""" + """Async void context manager.""" yield async def arun( self, url: str, - config: CrawlerRunConfig = None, + config: Optional[CrawlerRunConfig] = None, **kwargs, - ) -> RunManyReturn: + ) -> CrawlResultContainer: """ Runs the crawler for a single source: URL (web, local file, or raw HTML). Migration Guide: + Old way (deprecated): + result = await crawler.arun( url="https://example.com", word_count_threshold=200, @@ -287,6 +246,7 @@ async def arun( ) New way (recommended): + config = CrawlerRunConfig( word_count_threshold=200, screenshot=True, @@ -294,19 +254,45 @@ async def arun( ) result = await crawler.arun(url="https://example.com", crawler_config=config) + + If `config.stream` is True it returns a `CrawlResultContainer` which wraps an async generator. + + Example: + + async for result in crawl_result: + print(result.markdown) + + If `config.stream` is `False` and `config.deep_crawl_strategy` is `None` then the + `CrawlResultContainer` will contain a single `CrawlResult` which can be accessed directly. + + Example: + + result = await crawler.arun(url="https://example.com", config=config) + print(result.markdown) + + Otherwise it will contain a list of CrawlResult objects. + + Example: + + for result in crawl_result: + print(result.markdown) + + If `config.deep_crawl_strategy` is not `None` it returns a `CrawlResultContainer` which can + contain multiple `CrawlResult` objects. 
+ Args: - url: The URL to crawl (http://, https://, file://, or raw:) - crawler_config: Configuration object controlling crawl behavior + url (str): The URL to crawl (http://, https://, file://, or raw:) + crawler_config (Optional[CrawlerRunConfig]): Configuration object controlling crawl behavior [other parameters maintained for backwards compatibility] Returns: - CrawlResult: The result of crawling and processing + CrawlResultContainer: The result of crawling and processing """ # Auto-start if not ready if not self.ready: await self.start() - - config = config or CrawlerRunConfig() + + config = config or CrawlerRunConfig.from_kwargs(kwargs) if not isinstance(url, str) or not url: raise ValueError("Invalid URL, make sure the URL is a non-empty string") @@ -324,8 +310,8 @@ async def arun( ) # Initialize processing variables - async_response: AsyncCrawlResponse = None - cached_result: CrawlResult = None + async_response: Optional[AsyncCrawlResponse] = None + cached_result: Optional[CrawlResult] = None screenshot_data = None pdf_data = None extracted_content = None @@ -335,6 +321,7 @@ async def arun( if cache_context.should_read(): cached_result = await async_db_manager.aget_cached_url(url) + html: str = "" if cached_result: html = sanitize_input_encode(cached_result.html) extracted_content = sanitize_input_encode( @@ -384,14 +371,14 @@ async def arun( # Check robots.txt if enabled if config and config.check_robots_txt: if not await self.robots_parser.can_fetch(url, self.browser_config.user_agent): - return CrawlResult( + return CrawlResultContainer(CrawlResult( url=url, html="", success=False, status_code=403, error_message="Access denied by robots.txt", response_headers={"X-Robots-Status": "Blocked by robots.txt"} - ) + )) ############################## # Call CrawlerStrategy.crawl # @@ -399,6 +386,7 @@ async def arun( async_response = await self.crawler_strategy.crawl( url, config=config, # Pass the entire config object + **kwargs, ) html = sanitize_input_encode(async_response.html) @@ -417,7 +405,10 @@ async def arun( ############################################################### # Process the HTML content, Call CrawlerStrategy.process_html # ############################################################### - crawl_result : CrawlResult = await self.aprocess_html( + if "screenshot" in kwargs: + del kwargs["screenshot"] + + crawl_result: CrawlResult = await self.aprocess_html( url=url, html=html, extracted_content=extracted_content, @@ -494,7 +485,7 @@ async def arun( tag="ERROR", ) - return CrawlResultContainer( + return CrawlResultContainer( CrawlResult( url=url, html="", success=False, error_message=error_message ) @@ -504,10 +495,10 @@ async def aprocess_html( self, url: str, html: str, - extracted_content: str, + extracted_content: Optional[str], config: CrawlerRunConfig, - screenshot: str, - pdf_data: str, + screenshot: Optional[str], + pdf_data: Optional[bytes], verbose: bool, **kwargs, ) -> CrawlResult: @@ -526,6 +517,9 @@ async def aprocess_html( Returns: CrawlResult: Processed result containing extracted and formatted content + + Raises: + ValueError: If processing fails """ cleaned_html = "" try: @@ -555,7 +549,9 @@ async def aprocess_html( ) except InvalidCSSSelectorError as e: - raise ValueError(str(e)) + raise ValueError from e + except ValueError: + raise except Exception as e: raise ValueError( f"Process HTML, Failed to extract content from the website: {url}, error: {str(e)}" @@ -599,14 +595,10 @@ async def aprocess_html( params={"url": _url, "timing": 
int((time.perf_counter() - t1) * 1000) / 1000}, ) - ################################ - # Structured Content Extraction # - ################################ - if ( - not bool(extracted_content) - and config.extraction_strategy - and not isinstance(config.extraction_strategy, NoExtractionStrategy) - ): + ################################# + # Structured Content Extraction # + ################################# + if not extracted_content and config.extraction_strategy: t1 = time.perf_counter() # Choose content based on input_format content_format = config.extraction_strategy.input_format @@ -632,9 +624,11 @@ async def aprocess_html( else config.chunking_strategy ) sections = chunking.chunk(content) - extracted_content = config.extraction_strategy.run(url, sections) extracted_content = json.dumps( - extracted_content, indent=4, default=str, ensure_ascii=False + config.extraction_strategy.run(url, sections), + indent=4, + default=str, + ensure_ascii=False, ) # Log extraction completion @@ -685,8 +679,8 @@ async def arun_many( # pdf: bool = False, # user_agent: str = None, # verbose=True, - **kwargs - ) -> RunManyReturn: + **kwargs, + ) -> CrawlResultContainer: """ Runs the crawler for multiple URLs concurrently using a configurable dispatcher strategy. @@ -697,8 +691,9 @@ async def arun_many( [other parameters maintained for backwards compatibility] Returns: - Union[List[CrawlResult], AsyncGenerator[CrawlResult, None]]: - Either a list of all results or an async generator yielding results + CrawlResultContainer: + A container which encapsulates a list of all results for none streaming requests or + an async generator yielding results for streaming requests. Examples: @@ -717,21 +712,7 @@ async def arun_many( ): print(f"Processed {result.url}: {len(result.markdown)} chars") """ - config = config or CrawlerRunConfig() - # if config is None: - # config = CrawlerRunConfig( - # word_count_threshold=word_count_threshold, - # extraction_strategy=extraction_strategy, - # chunking_strategy=chunking_strategy, - # content_filter=content_filter, - # cache_mode=cache_mode, - # bypass_cache=bypass_cache, - # css_selector=css_selector, - # screenshot=screenshot, - # pdf=pdf, - # verbose=verbose, - # **kwargs, - # ) + config = config or CrawlerRunConfig.from_kwargs(kwargs) if dispatcher is None: dispatcher = MemoryAdaptiveDispatcher( @@ -740,34 +721,51 @@ async def arun_many( ), ) - def transform_result(task_result): - return ( - setattr(task_result.result, 'dispatch_result', - DispatchResult( - task_id=task_result.task_id, - memory_usage=task_result.memory_usage, - peak_memory=task_result.peak_memory, - start_time=task_result.start_time, - end_time=task_result.end_time, - error_message=task_result.error_message, - ) - ) or task_result.result + def transform_result(task_result: CrawlerTaskResult) -> CrawlResult: + """Transform a CrawlerTaskResult into a CrawlResult. + + Attaches the dispatch result to the CrawlResult. + + Args: + task_result (CrawlerTaskResult): The task result to transform. + Returns: + CrawlResult: The transformed crawl result. + Raises: + ValueError: If the task result does not contain exactly one result. + """ + results: CrawlResultContainer = task_result.result + if len(results) != 1: + raise ValueError( + f"Expected a single result, but got {len(results)} results." 
) - stream = config.stream - - if stream: - async def result_transformer(): - async for task_result in dispatcher.run_urls_stream(crawler=self, urls=urls, config=config): + result: CrawlResult = results[0] + result.dispatch_result = DispatchResult( + task_id=task_result.task_id, + memory_usage=task_result.memory_usage, + peak_memory=task_result.peak_memory, + start_time=task_result.start_time, + end_time=task_result.end_time, + error_message=task_result.error_message, + ) + + return result + + if config.stream: + async def result_transformer() -> AsyncGenerator[CrawlResult, None]: + async for task_result in dispatcher.run_urls_stream( + urls=urls, crawler=self, config=config + ): yield transform_result(task_result) - return result_transformer() - else: - _results = await dispatcher.run_urls(crawler=self, urls=urls, config=config) - return [transform_result(res) for res in _results] + + return CrawlResultContainer(result_transformer()) + + _results = await dispatcher.run_urls(urls=urls, crawler=self, config=config) + return CrawlResultContainer([transform_result(res) for res in _results]) async def aclear_cache(self): """Clear the cache database.""" - await async_db_manager.cleanup() + await async_db_manager.aclear_db() async def aflush_cache(self): """Flush the cache database.""" diff --git a/crawl4ai/browser/docker_registry.py b/crawl4ai/browser/docker_registry.py index 91f81c5e8..a271df7a4 100644 --- a/crawl4ai/browser/docker_registry.py +++ b/crawl4ai/browser/docker_registry.py @@ -7,7 +7,7 @@ import os import json import time -from typing import Dict, Optional +from typing import Optional from ..utils import get_home_folder diff --git a/crawl4ai/browser/docker_strategy.py b/crawl4ai/browser/docker_strategy.py index 639abd845..600baf6a1 100644 --- a/crawl4ai/browser/docker_strategy.py +++ b/crawl4ai/browser/docker_strategy.py @@ -6,14 +6,11 @@ import os import uuid -import asyncio -from typing import Dict, List, Optional, Tuple, Union -from pathlib import Path +from typing import List, Optional -from playwright.async_api import Page, BrowserContext from ..async_logger import AsyncLogger -from ..async_configs import BrowserConfig, CrawlerRunConfig +from ..async_configs import BrowserConfig from .docker_config import DockerConfig from .docker_registry import DockerRegistry from .docker_utils import DockerUtils diff --git a/crawl4ai/browser/docker_utils.py b/crawl4ai/browser/docker_utils.py index 0597c2d50..cf37ead42 100644 --- a/crawl4ai/browser/docker_utils.py +++ b/crawl4ai/browser/docker_utils.py @@ -5,8 +5,7 @@ import tempfile import shutil import socket -import subprocess -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple class DockerUtils: """Utility class for Docker operations in browser automation. @@ -66,9 +65,9 @@ async def check_image_exists(self, image_name: str) -> bool: if self.logger: self.logger.debug(f"Error checking if image exists: {str(e)}", tag="DOCKER") return False - - async def build_docker_image(self, image_name: str, dockerfile_path: str, - files_to_copy: Dict[str, str] = None) -> bool: + + async def build_docker_image(self, image_name: str, dockerfile_path: str, + files_to_copy: Optional[Dict[str, str]] = None) -> bool: """Build a Docker image from a Dockerfile. 
Args: @@ -116,8 +115,8 @@ async def build_docker_image(self, image_name: str, dockerfile_path: str, if self.logger: self.logger.success(f"Successfully built Docker image: {image_name}", tag="DOCKER") return True - - async def ensure_docker_image_exists(self, image_name: str, mode: str = "connect") -> str: + + async def ensure_docker_image_exists(self, image_name: Optional[str] = None, mode: str = "connect") -> str: """Ensure the required Docker image exists, creating it if necessary. Args: @@ -174,10 +173,10 @@ async def ensure_docker_image_exists(self, image_name: str, mode: str = "connect async def create_container(self, image_name: str, host_port: int, container_name: Optional[str] = None, - volumes: List[str] = None, + volumes: Optional[List[str]] = None, network: Optional[str] = None, - env_vars: Dict[str, str] = None, - extra_args: List[str] = None) -> Optional[str]: + env_vars: Optional[Dict[str, str]] = None, + extra_args: Optional[List[str]] = None) -> Optional[str]: """Create a new Docker container. Args: diff --git a/crawl4ai/browser/strategies.py b/crawl4ai/browser/strategies.py index f2a9525e0..a8dfd98dd 100644 --- a/crawl4ai/browser/strategies.py +++ b/crawl4ai/browser/strategies.py @@ -13,6 +13,7 @@ import subprocess import shutil import signal +import re from typing import Optional, Dict, Tuple, List, Any from playwright.async_api import BrowserContext, Page, ProxySettings @@ -64,7 +65,7 @@ def __init__(self, config: BrowserConfig, logger: Optional[AsyncLogger] = None): self.playwright = None @abstractmethod - async def start(self): + async def start(self) -> Any: """Start the browser. Returns: @@ -574,8 +575,11 @@ async def start(self): cdp_url = await self._get_or_create_cdp_url() # Connect to the browser using CDP - self.browser = await self.playwright.chromium.connect_over_cdp(cdp_url) - + if self.config.browser_type == "chromium": + self.browser = await self.playwright.chromium.connect_over_cdp(cdp_url) + else: + raise NotImplementedError(f"Browser type {self.config.browser_type} not supported") + # Get or create default context contexts = self.browser.contexts if contexts: @@ -610,6 +614,8 @@ async def _get_or_create_cdp_url(self) -> str: try: # Use DETACHED_PROCESS flag on Windows to fully detach the process # On Unix, we'll use preexec_fn=os.setpgrp to start the process in a new process group + # TODO: fix as using a pipe here can cause the process to hang + # if the pipe becomes full and nothing is reading from it. 
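+            # (Draining stdout/stderr on a background thread, or redirecting them to a
+            # temporary file, would keep the pipe from filling; noted as a possible fix.)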
if is_windows(): self.browser_process = subprocess.Popen( args, @@ -679,15 +685,6 @@ async def _get_browser_args(self, user_data_dir: str) -> List[str]: ] if self.config.headless: args.append("--headless=new") - elif self.config.browser_type == "firefox": - args = [ - "--remote-debugging-port", - str(self.config.debugging_port), - "--profile", - user_data_dir, - ] - if self.config.headless: - args.append("--headless") else: raise NotImplementedError(f"Browser type {self.config.browser_type} not supported") @@ -810,6 +807,10 @@ async def get_page(self, crawlerRunConfig: CrawlerRunConfig) -> Tuple[Page, Brow # For CDP, we typically use the shared default_context context = self.default_context pages = context.pages + + # TODO: Fix this so it doesn't create a new page for the first request + # Where the URL is None or empty and we have a page displaying: + # chrome://new-tab-page/ page = next((p for p in pages if p.url == crawlerRunConfig.url), None) if not page: page = await context.new_page() @@ -894,6 +895,7 @@ def __init__(self, config: BrowserConfig, logger: Optional[AsyncLogger] = None): super().__init__(config, logger) self.builtin_browser_dir = os.path.join(get_home_folder(), "builtin-browser") if not self.config.user_data_dir else self.config.user_data_dir self.builtin_config_file = os.path.join(self.builtin_browser_dir, "browser_config.json") + self.pid: int = 0 # Raise error if user data dir is already engaged if self._check_user_dir_is_engaged(self.builtin_browser_dir): @@ -1016,7 +1018,7 @@ async def launch_builtin_browser(self, """Launch a browser in the background for use as the built-in browser. Args: - browser_type: Type of browser to launch ('chromium' or 'firefox') + browser_type: Type of browser to launch ('chromium') debugging_port: Port to use for CDP debugging headless: Whether to run in headless mode @@ -1053,16 +1055,6 @@ async def launch_builtin_browser(self, ] if headless: args.append("--headless=new") - elif browser_type == "firefox": - args = [ - browser_path, - "--remote-debugging-port", - str(debugging_port), - "--profile", - user_data_dir, - ] - if headless: - args.append("--headless") else: if self.logger: self.logger.error(f"Browser type {browser_type} not supported for built-in browser", tag="BUILTIN") @@ -1093,7 +1085,22 @@ async def launch_builtin_browser(self, if self.logger: self.logger.error(f"Browser process exited immediately with code {process.returncode}", tag="BUILTIN") return None - + + self.pid = process.pid + + if debugging_port == 0: + # Determine the actual debugging port used. 
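+            # With --remote-debugging-port=0 the browser chooses a free port and reports
+            # the DevTools ws:// endpoint on stderr, which is parsed below.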
+ try: + process.communicate(timeout=0.1) + except subprocess.TimeoutExpired as e: + if not e.stderr: + raise Exception("Unable to determine debugging port") from e + match: Optional[re.Match] = re.search(r"ws://[^:]+:(\d+)", e.stderr.decode()) + if not match: + raise Exception("Unable to determine debugging port") from e + debugging_port = int(match.group(1)) + self.config.debugging_port = debugging_port + # Construct CDP URL cdp_url = f"http://localhost:{debugging_port}" @@ -1141,7 +1148,7 @@ async def launch_builtin_browser(self, # Convert legacy format to port mapping elif isinstance(existing_data, dict) and "debugging_port" in existing_data: old_port = str(existing_data.get("debugging_port")) - if self._is_browser_running(existing_data.get("pid")): + if is_browser_running(existing_data.get("pid")): port_map[old_port] = existing_data except Exception as e: if self.logger: @@ -1190,13 +1197,21 @@ async def kill_builtin_browser(self) -> bool: os.kill(pid, signal.SIGTERM) # Wait for termination for _ in range(5): + if self.pid == pid: + # We created the process, so wait for it otherwise it will + # become a zombie process causing retry until SIGKILL. + os.waitpid(pid, os.WNOHANG) + if not is_browser_running(pid): break await asyncio.sleep(0.5) else: # Force kill if still running os.kill(pid, signal.SIGKILL) - + + if self.pid == pid: + self.pid = 0 + # Update config file to remove this browser with open(self.builtin_config_file, 'r') as f: browser_info_dict = json.load(f) @@ -1254,3 +1269,5 @@ async def close(self): # Clean up built-in browser if we created it if self.shutting_down: await self.kill_builtin_browser() + + self.pid = 0 diff --git a/crawl4ai/browser_manager.py b/crawl4ai/browser_manager.py index df0886c75..b7654699e 100644 --- a/crawl4ai/browser_manager.py +++ b/crawl4ai/browser_manager.py @@ -6,7 +6,7 @@ import shutil import tempfile import subprocess -from playwright.async_api import BrowserContext +from playwright.async_api import BrowserContext, async_playwright import hashlib from .js_snippet import load_js_script from .config import DOWNLOAD_PAGE_TIMEOUT @@ -435,15 +435,6 @@ class BrowserManager: session_ttl (int): Session timeout in seconds """ - _playwright_instance = None - - @classmethod - async def get_playwright(cls): - from playwright.async_api import async_playwright - if cls._playwright_instance is None: - cls._playwright_instance = await async_playwright().start() - return cls._playwright_instance - def __init__(self, browser_config: BrowserConfig, logger=None): """ Initialize the BrowserManager with a browser configuration. @@ -492,11 +483,8 @@ async def start(self): Note: This method should be called in a separate task to avoid blocking the main event loop. 
""" - self.playwright = await self.get_playwright() - if self.playwright is None: - from playwright.async_api import async_playwright - self.playwright = await async_playwright().start() + self.playwright = await async_playwright().start() if self.config.cdp_url or self.config.use_managed_browser: self.config.use_managed_browser = True @@ -595,7 +583,7 @@ def _build_browser_args(self) -> dict: async def setup_context( self, context: BrowserContext, - crawlerRunConfig: CrawlerRunConfig = None, + crawlerRunConfig: Optional[CrawlerRunConfig] = None, is_default=False, ): """ @@ -621,7 +609,7 @@ async def setup_context( Args: context (BrowserContext): The browser context to set up - crawlerRunConfig (CrawlerRunConfig): Configuration object containing all browser settings + crawlerRunConfig (CrawlerRunConfig or None): Configuration object containing all browser settings is_default (bool): Flag indicating if this is the default context Returns: None @@ -673,9 +661,9 @@ async def setup_context( or crawlerRunConfig.simulate_user or crawlerRunConfig.magic ): - await context.add_init_script(load_js_script("navigator_overrider")) + await context.add_init_script(load_js_script("navigator_overrider")) - async def create_browser_context(self, crawlerRunConfig: CrawlerRunConfig = None): + async def create_browser_context(self, crawlerRunConfig: Optional[CrawlerRunConfig] = None): """ Creates and returns a new browser context with configured settings. Applies text-only mode settings if text_mode is enabled in config. diff --git a/crawl4ai/chunking_strategy.py b/crawl4ai/chunking_strategy.py index f46cb667c..9da2c306e 100644 --- a/crawl4ai/chunking_strategy.py +++ b/crawl4ai/chunking_strategy.py @@ -71,7 +71,6 @@ def __init__(self, **kwargs): """ Initialize the NlpSentenceChunking object. 
""" - from crawl4ai.le.legacy.model_loader import load_nltk_punkt load_nltk_punkt() def chunk(self, text: str) -> list: diff --git a/crawl4ai/cli.py b/crawl4ai/cli.py index 51477d6b3..0a2b7c6ec 100644 --- a/crawl4ai/cli.py +++ b/crawl4ai/cli.py @@ -15,21 +15,21 @@ from crawl4ai import ( CacheMode, - AsyncWebCrawler, - CrawlResult, - BrowserConfig, + AsyncWebCrawler, + BrowserConfig, CrawlerRunConfig, - LLMExtractionStrategy, + LLMExtractionStrategy, LXMLWebScrapingStrategy, JsonCssExtractionStrategy, JsonXPathExtractionStrategy, - BM25ContentFilter, + BM25ContentFilter, PruningContentFilter, BrowserProfiler, DefaultMarkdownGenerator, LLMConfig ) from crawl4ai.config import USER_SETTINGS +from crawl4ai.models import CrawlResultContainer from litellm import completion from pathlib import Path @@ -139,7 +139,7 @@ def load_config_file(path: Optional[str]) -> dict: except Exception as e: raise click.BadParameter(f'Error loading config file {path}: {str(e)}') -def load_schema_file(path: Optional[str]) -> dict: +def load_schema_file(path: Optional[str]) -> Optional[dict]: if not path: return None return load_config_file(path) @@ -806,8 +806,7 @@ def browser_view_cmd(url: Optional[str]): # Use the CDP URL to launch a new visible window import subprocess - import os - + # Determine the browser command based on platform if sys.platform == "darwin": # macOS browser_cmd = ["/Applications/Google Chrome.app/Contents/MacOS/Google Chrome"] @@ -1069,6 +1068,9 @@ def crawl_cmd(url: str, browser_config: str, crawler_config: str, filter_config: "query": "", "threshold": 0.48 } + else: + filter_conf = {} + if filter_conf["type"] == "bm25": crawler_cfg.markdown_generator = DefaultMarkdownGenerator( content_filter = BM25ContentFilter( @@ -1161,7 +1163,7 @@ def crawl_cmd(url: str, browser_config: str, crawler_config: str, filter_config: crawler_cfg.verbose = config.get("VERBOSE", False) # Run crawler - result : CrawlResult = anyio.run( + result: CrawlResultContainer = anyio.run( run_crawler, url, browser_cfg, @@ -1304,7 +1306,7 @@ def config_set_cmd(key: str, value: str): elif value.lower() in ["false", "no", "0", "n"]: typed_value = False else: - console.print(f"[red]Error: Invalid boolean value. Use 'true' or 'false'.[/red]") + console.print("[red]Error: Invalid boolean value. Use 'true' or 'false'.[/red]") return elif setting["type"] == "string": typed_value = value diff --git a/crawl4ai/components/crawler_monitor.py b/crawl4ai/components/crawler_monitor.py index 49bf9a150..16a1aabf3 100644 --- a/crawl4ai/components/crawler_monitor.py +++ b/crawl4ai/components/crawler_monitor.py @@ -1,36 +1,47 @@ +from __future__ import annotations + import time import uuid import threading +import errno import psutil -from datetime import datetime, timedelta -from typing import Dict, Optional, List -import threading +import sys +from datetime import timedelta +from typing import Dict, Optional from rich.console import Console from rich.layout import Layout from rich.panel import Panel from rich.table import Table from rich.text import Text from rich.live import Live -from rich import box from ..models import CrawlStatus class TerminalUI: """Terminal user interface for CrawlerMonitor using rich library.""" - - def __init__(self, refresh_rate: float = 1.0, max_width: int = 120): + + def __init__(self, monitor: CrawlerMonitor, refresh_rate: float = 1.0, max_width: int = 120): """ Initialize the terminal UI. 
Args: refresh_rate: How often to refresh the UI (in seconds) max_width: Maximum width of the UI in characters - """ + + Raises: + OSError: If sys.stdin is not a terminal. + """ + if not sys.stdin.isatty(): + # Can't set cbreak mode if stdin is not a terminal, such as running in pytest. + # We check early as the UI loop runs in a separate thread, which would hang + # if we try to set cbreak mode later. + raise OSError(errno.ENOTTY, "stdin is not a terminal") + self.console = Console(width=max_width) self.layout = Layout() self.refresh_rate = refresh_rate self.stop_event = threading.Event() self.ui_thread = None - self.monitor = None # Will be set by CrawlerMonitor + self.monitor: CrawlerMonitor = monitor self.max_width = max_width # Setup layout - vertical layout (top to bottom) @@ -40,10 +51,9 @@ def __init__(self, refresh_rate: float = 1.0, max_width: int = 120): Layout(name="task_details", ratio=1), Layout(name="footer", size=3) # Increased footer size to fit all content ) - - def start(self, monitor): + + def start(self): """Start the UI thread.""" - self.monitor = monitor self.stop_event.clear() self.ui_thread = threading.Thread(target=self._ui_loop) self.ui_thread.daemon = True @@ -358,7 +368,7 @@ def __init__( self, urls_total: int = 0, refresh_rate: float = 1.0, - enable_ui: bool = True, + enable_ui: Optional[bool] = None, max_width: int = 120 ): """ @@ -401,12 +411,13 @@ def __init__( self._lock = threading.RLock() # Terminal UI - self.enable_ui = enable_ui + self.enable_ui = sys.stdin.isatty() if enable_ui is None else enable_ui self.terminal_ui = TerminalUI( - refresh_rate=refresh_rate, + monitor=self, + refresh_rate=refresh_rate, max_width=max_width - ) if enable_ui else None - + ) if self.enable_ui else None + def start(self): """ Start the monitoring session. @@ -421,8 +432,8 @@ def start(self): # Start the terminal UI if self.enable_ui and self.terminal_ui: - self.terminal_ui.start(self) - + self.terminal_ui.start() + def stop(self): """ Stop the monitoring session. diff --git a/crawl4ai/content_filter_strategy.py b/crawl4ai/content_filter_strategy.py index 8d7a51b49..c37d1b0af 100644 --- a/crawl4ai/content_filter_strategy.py +++ b/crawl4ai/content_filter_strategy.py @@ -5,7 +5,7 @@ from typing import List, Tuple, Dict, Optional from rank_bm25 import BM25Okapi from collections import deque -from bs4 import NavigableString, Comment +from bs4.element import NavigableString, Comment from .utils import ( clean_tokens, @@ -31,12 +31,42 @@ from colorama import Fore, Style +# TODO: remove once https://github.com/dorianbrown/rank_bm25/pull/40 has +# been merged and we have updated the dependency. +class BM25OkapiFixed(BM25Okapi): + """BM25 implementation with which avoids zero idf values.""" + + def _calc_idf(self, nd: dict[str, int]) -> None: + """ + Calculates frequencies of terms in documents and in corpus. + This algorithm sets a floor on the idf values to eps * average_idf + """ + # log( (N - n(qi) + 0.5) / (n(qi) + 0.5) + 1) + # where N is the total number of documents in the collection, + # and n(qi) is the number of documents containing qi. + # We use a refactored version of the formula avoiding the division + # to improve performance. 
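+        # Since (N - n(qi) + 0.5) / (n(qi) + 0.5) + 1 == (N + 1) / (n(qi) + 0.5),
+        # the idf computed below reduces to log(N + 1) - log(n(qi) + 0.5).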
+ idf_sum: float = 0 + negative_idfs: list[str] = [] + for word, freq in nd.items(): + idf: float = math.log(self.corpus_size + 1) - math.log(freq + 0.5) + self.idf[word] = idf + idf_sum += idf + if idf < 0: + negative_idfs.append(word) + self.average_idf = idf_sum / len(self.idf) + + eps: float = self.epsilon * self.average_idf + for word in negative_idfs: + self.idf[word] = eps + + class RelevantContentFilter(ABC): """Abstract base class for content filtering strategies""" def __init__( self, - user_query: str = None, + user_query: Optional[str] = None, verbose: bool = False, logger: Optional[AsyncLogger] = None, ): @@ -131,46 +161,63 @@ def extract_page_query(self, soup: BeautifulSoup, body: Tag) -> str: query_parts = [] # Title - try: - title = soup.title.string - if title: - query_parts.append(title) - except Exception: - pass + if title := soup.title: + query_parts.append(title.string) - if soup.find("h1"): - query_parts.append(soup.find("h1").get_text()) + # Tags that typically contain meaningful headers. + HEADER_TAGS = {"h1", "h2", "h3", "h4", "h5", "h6", "header"} + for tag in HEADER_TAGS: + if header := soup.find(tag): + query_parts.append(header.get_text()) # Meta tags - temp = "" + empty: bool = True for meta_name in ["keywords", "description"]: meta = soup.find("meta", attrs={"name": meta_name}) - if meta and meta.get("content"): - query_parts.append(meta["content"]) - temp += meta["content"] + if meta and isinstance(meta, Tag): + attrib: Union[str, list[str], None] = meta.get("content") + if attrib: + content: str = ( + attrib if isinstance(attrib, str) else " ".join(attrib) + ) + # TODO: should this be split? + query_parts.append(content.replace(",", " ").replace(" ", " ")) + empty = False # If still empty, grab first significant paragraph - if not temp: - # Find the first tag P thatits text contains more than 50 characters + if empty: + # Find the first tag P that its text contains more than 150 characters for p in body.find_all("p"): - if len(p.get_text()) > 150: - query_parts.append(p.get_text()[:150]) + text: str = p.get_text() + if len(text) > 150: + # Find the last space within the first 150 characters. + if len(text) == 150 or text[150] == " ": + # End of text is at word boundary. + query_parts.append(text[:150]) + else: + last_space_pos: int = text[:150].rfind(" ") + if last_space_pos > 0: + # Include only complete words up to the last space. + query_parts.append(text[:last_space_pos]) + else: + # Fallback if no space found and not at word boundary. + query_parts.append(text[:150]) break return " ".join(filter(None, query_parts)) def extract_text_chunks( - self, body: Tag, min_word_threshold: int = None - ) -> List[Tuple[str, str]]: + self, body: Tag, min_word_threshold: Optional[int] = None + ) -> List[Tuple[int, str, str, Tag]]: """ Extracts text chunks from a BeautifulSoup body element while preserving order. - Returns list of tuples (text, tag_name) for classification. + Returns list of tuples for classification. 
Args: body: BeautifulSoup Tag object representing the body element Returns: - List of (text, tag_name) tuples + List of (index, chunk, tag_type, tag) tuples """ # Tags to ignore - inline elements that shouldn't break text flow INLINE_TAGS = { @@ -403,7 +450,7 @@ class BM25ContentFilter(RelevantContentFilter): def __init__( self, - user_query: str = None, + user_query: Optional[str] = None, bm25_threshold: float = 1.0, language: str = "english", ): @@ -435,7 +482,7 @@ def __init__( } self.stemmer = stemmer(language) - def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: + def filter_content(self, html: str, min_word_threshold: Optional[int] = None) -> List[str]: """ Implements content filtering using BM25 algorithm with priority tag handling. @@ -459,27 +506,20 @@ def filter_content(self, html: str, min_word_threshold: int = None) -> List[str] if not soup.body: # Wrap in body tag if missing soup = BeautifulSoup(f"{html}", "lxml") - body = soup.find("body") - query = self.extract_page_query(soup, body) + body = soup.body + if not body: + return [] + query = self.extract_page_query(soup, body) if not query: return [] - # return [self.clean_element(soup)] candidates = self.extract_text_chunks(body, min_word_threshold) - if not candidates: return [] - # Tokenize corpus - # tokenized_corpus = [chunk.lower().split() for _, chunk, _, _ in candidates] - # tokenized_query = query.lower().split() - - # tokenized_corpus = [[ps.stem(word) for word in chunk.lower().split()] - # for _, chunk, _, _ in candidates] - # tokenized_query = [ps.stem(word) for word in query.lower().split()] - + # Tokenize corpus and query tokenized_corpus = [ [self.stemmer.stemWord(word) for word in chunk.lower().split()] for _, chunk, _, _ in candidates @@ -488,22 +528,21 @@ def filter_content(self, html: str, min_word_threshold: int = None) -> List[str] self.stemmer.stemWord(word) for word in query.lower().split() ] - # tokenized_corpus = [[self.stemmer.stemWord(word) for word in tokenize_text(chunk.lower())] - # for _, chunk, _, _ in candidates] - # tokenized_query = [self.stemmer.stemWord(word) for word in tokenize_text(query.lower())] - # Clean from stop words and noise tokenized_corpus = [clean_tokens(tokens) for tokens in tokenized_corpus] tokenized_query = clean_tokens(tokenized_query) - bm25 = BM25Okapi(tokenized_corpus) + bm25 = BM25OkapiFixed(tokenized_corpus) scores = bm25.get_scores(tokenized_query) # Adjust scores with tag weights adjusted_candidates = [] - for score, (index, chunk, tag_type, tag) in zip(scores, candidates): - tag_weight = self.priority_tags.get(tag.name, 1.0) - adjusted_score = score * tag_weight + for score, (index, chunk, _, tag) in zip(scores, candidates): + if score: + tag_weight = self.priority_tags.get(tag.name, 1.0) + adjusted_score = score * tag_weight + else: + adjusted_score = score adjusted_candidates.append((adjusted_score, index, chunk, tag)) # Filter candidates by threshold @@ -546,8 +585,8 @@ class PruningContentFilter(RelevantContentFilter): def __init__( self, - user_query: str = None, - min_word_threshold: int = None, + user_query: Optional[str] = None, + min_word_threshold: Optional[int] = None, threshold_type: str = "fixed", threshold: float = 0.48, ): @@ -563,7 +602,7 @@ def __init__( threshold_type (str): Threshold type for dynamic threshold (default: 'fixed'). threshold (float): Fixed threshold value (default: 0.48). 
""" - super().__init__(None) + super().__init__(user_query=user_query) self.min_word_threshold = min_word_threshold self.threshold_type = threshold_type self.threshold = threshold @@ -615,7 +654,9 @@ def __init__( "h6": 0.7, } - def filter_content(self, html: str, min_word_threshold: int = None) -> List[str]: + def filter_content( + self, html: str, min_word_threshold: Optional[int] = None + ) -> List[str]: """ Implements content filtering using pruning algorithm with dynamic threshold. @@ -798,8 +839,8 @@ class LLMContentFilter(RelevantContentFilter): def __init__( self, - llm_config: "LLMConfig" = None, - instruction: str = None, + llm_config: Optional[LLMConfig] = None, + instruction: Optional[str] = None, chunk_token_threshold: int = int(1e9), overlap_rate: float = OVERLAP_RATE, word_token_rate: float = WORD_TOKEN_RATE, @@ -813,7 +854,7 @@ def __init__( api_token: Optional[str] = None, base_url: Optional[str] = None, api_base: Optional[str] = None, - extra_args: Dict = None, + extra_args: Optional[Dict] = None, ): super().__init__(None) self.provider = provider diff --git a/crawl4ai/content_scraping_strategy.py b/crawl4ai/content_scraping_strategy.py index a806b045a..3ff8d1c72 100644 --- a/crawl4ai/content_scraping_strategy.py +++ b/crawl4ai/content_scraping_strategy.py @@ -13,8 +13,7 @@ IMPORTANT_ATTRS, SOCIAL_MEDIA_DOMAINS, ) -from bs4 import NavigableString, Comment -from bs4 import PageElement, Tag +from bs4.element import PageElement, Tag, NavigableString, Comment from urllib.parse import urljoin from requests.exceptions import InvalidSchema from .utils import ( @@ -172,6 +171,7 @@ def scrap(self, url: str, html: str, **kwargs) -> ScrapingResult: ], ) + # TODO: this drops errors which are indicated by message. return ScrapingResult( cleaned_html=raw_result.get("cleaned_html", ""), success=raw_result.get("success", False), @@ -192,7 +192,7 @@ async def ascrap(self, url: str, html: str, **kwargs) -> ScrapingResult: Returns: ScrapingResult: A structured result containing the scraped content. """ - return await asyncio.to_thread(self._scrap, url, html, **kwargs) + return await asyncio.to_thread(self.scrap, url, html, **kwargs) def is_data_table(self, table: Tag, **kwargs) -> bool: """ @@ -362,7 +362,19 @@ def flatten_nested_elements(self, node): node.contents = [self.flatten_nested_elements(child) for child in node.contents] return node - def find_closest_parent_with_useful_text(self, tag, **kwargs): + def clean_description(self, description) -> str: + """ + Clean the description text. + + Args: + description (str): The description text to clean. + + Returns: + str: The cleaned description text. + """ + return re.sub(r"\s+", " ", description).strip() + + def find_closest_parent_with_useful_text(self, tag, **kwargs) -> Optional[str]: """ Find the closest parent with useful text. 
@@ -384,7 +396,7 @@ def find_closest_parent_with_useful_text(self, tag, **kwargs): text_content = current_tag.get_text(separator=" ", strip=True) # Check if the text content has at least word_count_threshold if len(text_content.split()) >= image_description_min_word_threshold: - return text_content + return self.clean_description(text_content) return None def remove_unwanted_attributes( @@ -691,7 +703,8 @@ def _process_element( potential_sources = [ "src", "data-src", - "srcset" "data-lazy-src", + "srcset", + "data-lazy-src", "data-original", ] src = element.get("src", "") @@ -755,7 +768,7 @@ def _process_element( "src": element.get("src"), "alt": element.get("alt"), "type": element.name, - "description": self.find_closest_parent_with_useful_text( + "desc": self.find_closest_parent_with_useful_text( element, **kwargs ), } @@ -767,7 +780,7 @@ def _process_element( "src": source_tag.get("src"), "alt": element.get("alt"), "type": element.name, - "description": self.find_closest_parent_with_useful_text( + "desc": self.find_closest_parent_with_useful_text( element, **kwargs ), } @@ -835,10 +848,10 @@ def _scrap( url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, - css_selector: str = None, - target_elements: List[str] = None, + css_selector: Optional[str] = None, + target_elements: Optional[List[str]] = None, **kwargs, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: """ Extract content from HTML using BeautifulSoup. @@ -846,7 +859,7 @@ def _scrap( url (str): The URL of the page to scrape. html (str): The HTML content of the page to scrape. word_count_threshold (int): The minimum word count threshold for content extraction. - css_selector (str): The CSS selector to use for content extraction. + css_selector (str or None): The CSS selector to use for content extraction. **kwargs: Additional keyword arguments. 
Returns: @@ -859,6 +872,9 @@ def _scrap( parser_type = kwargs.get("parser", "lxml") soup = BeautifulSoup(html, parser_type) body = soup.body + if body is None: + return None + base_domain = get_base_domain(url) try: @@ -949,7 +965,7 @@ def _scrap( links["internal"] = list(internal_links_dict.values()) links["external"] = list(external_links_dict.values()) - # # Process images using ThreadPoolExecutor + # Process images using ThreadPoolExecutor imgs = body.find_all("img") media["images"] = [ @@ -974,7 +990,11 @@ def _scrap( body = self.flatten_nested_elements(body) base64_pattern = re.compile(r'data:image/[^;]+;base64,([^"]+)') for img in imgs: + if not isinstance(img, Tag): + continue src = img.get("src", "") + if not src or not isinstance(src, str): + continue if base64_pattern.match(src): # Replace base64 data with empty string img["src"] = base64_pattern.sub("", src) @@ -1119,9 +1139,7 @@ def _process_element( "src": elem.get("src"), "alt": elem.get("alt"), "type": media_type, - "description": self.find_closest_parent_with_useful_text( - elem, **kwargs - ), + "desc": self.find_closest_parent_with_useful_text(elem, **kwargs), } media[f"{media_type}s"].append(media_info) @@ -1159,10 +1177,10 @@ def find_closest_parent_with_useful_text( while current is not None: if ( current.text - and len(current.text_content().split()) + and len(current.text_content().strip().split()) >= image_description_min_word_threshold ): - return current.text_content().strip() + return self.clean_description(current.text_content()) current = current.getparent() return None @@ -1211,9 +1229,9 @@ def process_image( # Score calculation score = 0 - if (width := img.get("width")) and width.isdigit(): + if (width := img.get("width", "")) and width.isdigit(): score += 1 if int(width) > 150 else 0 - if (height := img.get("height")) and height.isdigit(): + if (height := img.get("height", "")) and height.isdigit(): score += 1 if int(height) > 150 else 0 if alt: score += 1 @@ -1476,10 +1494,10 @@ def _scrap( url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, - css_selector: str = None, - target_elements: List[str] = None, + css_selector: Optional[str] = None, + target_elements: Optional[List[str]] = None, **kwargs, - ) -> Dict[str, Any]: + ) -> Optional[Dict[str, Any]]: if not html: return None @@ -1623,7 +1641,7 @@ def _scrap( # Remove empty elements self.remove_empty_elements_fast(body, 1) - # Remvoe unneeded attributes + # Remove unneeded attributes self.remove_unwanted_attributes_fast( body, keep_data_attributes=kwargs.get("keep_data_attributes", False) ) diff --git a/crawl4ai/crawlers/amazon_product/crawler.py b/crawl4ai/crawlers/amazon_product/crawler.py index 45cc9d6ad..7ad08ec84 100644 --- a/crawl4ai/crawlers/amazon_product/crawler.py +++ b/crawl4ai/crawlers/amazon_product/crawler.py @@ -1,3 +1,5 @@ +import json + from crawl4ai.hub import BaseCrawler __meta__ = { @@ -8,7 +10,7 @@ } class AmazonProductCrawler(BaseCrawler): - async def run(self, url: str, **kwargs) -> str: + async def run(self, url: str = "", **kwargs) -> str: try: self.logger.info(f"Crawling {url}") return '{"product": {"name": "Test Amazon Product"}}' diff --git a/crawl4ai/crawlers/google_search/crawler.py b/crawl4ai/crawlers/google_search/crawler.py index e1288de1d..dc18f626c 100644 --- a/crawl4ai/crawlers/google_search/crawler.py +++ b/crawl4ai/crawlers/google_search/crawler.py @@ -1,11 +1,11 @@ from crawl4ai import BrowserConfig, AsyncWebCrawler, CrawlerRunConfig, CacheMode from crawl4ai.hub import BaseCrawler -from 
crawl4ai.utils import optimize_html, get_home_folder, preprocess_html_for_schema +from crawl4ai.utils import get_home_folder, preprocess_html_for_schema from crawl4ai.extraction_strategy import JsonCssExtractionStrategy from pathlib import Path import json import os -from typing import Dict +from typing import Dict, Optional class GoogleSearchCrawler(BaseCrawler): @@ -21,7 +21,14 @@ def __init__(self): self.js_script = (Path(__file__).parent / "script.js").read_text() - async def run(self, url="", query: str = "", search_type: str = "text", schema_cache_path = None, **kwargs) -> str: + async def run( + self, + url="", + query: str = "", + search_type: str = "text", + schema_cache_path: Optional[str] = None, + **kwargs, + ) -> str: """Crawl Google Search results for a query""" url = f"https://www.google.com/search?q={query}&gl=sg&hl=en" if search_type == "text" else f"https://www.google.com/search?q={query}&gl=sg&hl=en&tbs=qdr:d&udm=2" if kwargs.get("page_start", 1) > 1: @@ -42,18 +49,21 @@ async def run(self, url="", query: str = "", search_type: str = "text", schema_c result = await crawler.arun(url=url, config=config) if not result.success: - return json.dumps({"error": result.error}) + return json.dumps({"error": result.error_message}) if search_type == "image": - if result.js_execution_result.get("success", False) is False: - return json.dumps({"error": result.js_execution_result.get("error", "Unknown error")}) - if "results" in result.js_execution_result: - image_result = result.js_execution_result['results'][0] - if image_result.get("success", False) is False: - return json.dumps({"error": image_result.get("error", "Unknown error")}) - return json.dumps(image_result["result"], indent=4) + if result.js_execution_result: + if result.js_execution_result.get("success", False) is False: + return json.dumps({"error": result.js_execution_result.get("error", "Unknown error")}) + if "results" in result.js_execution_result: + image_result = result.js_execution_result['results'][0] + if image_result.get("success", False) is False: + return json.dumps({"error": image_result.get("error", "Unknown error")}) + return json.dumps(image_result["result"], indent=4) # For text search, extract structured data + if not result.cleaned_html: + return json.dumps({"error": "No HTML content found"}) schemas = await self._build_schemas(result.cleaned_html, schema_cache_path) extracted = { key: JsonCssExtractionStrategy(schema=schemas[key]).run( @@ -63,13 +73,14 @@ async def run(self, url="", query: str = "", search_type: str = "text", schema_c } return json.dumps(extracted, indent=4) - async def _build_schemas(self, html: str, schema_cache_path: str = None) -> Dict[str, Dict]: + async def _build_schemas( + self, html: str, schema_cache_path: Optional[str] = None + ) -> Dict[str, Dict]: """Build extraction schemas (organic, top stories, etc.)""" home_dir = get_home_folder() if not schema_cache_path else schema_cache_path os.makedirs(f"{home_dir}/schema", exist_ok=True) - # cleaned_html = optimize_html(html, threshold=100) - cleaned_html = preprocess_html_for_schema(html) + cleaned_html = preprocess_html_for_schema(html) organic_schema = None if os.path.exists(f"{home_dir}/schema/organic_schema.json"): diff --git a/crawl4ai/deep_crawling/base_strategy.py b/crawl4ai/deep_crawling/base_strategy.py index e1b3fe6bd..26eaa4a99 100644 --- a/crawl4ai/deep_crawling/base_strategy.py +++ b/crawl4ai/deep_crawling/base_strategy.py @@ -1,11 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from 
typing import AsyncGenerator, Optional, Set, List, Dict -from functools import wraps from contextvars import ContextVar -from ..types import AsyncWebCrawler, CrawlerRunConfig, CrawlResult, RunManyReturn +from functools import wraps +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Set, Awaitable +from typing_extensions import Concatenate +from ..types import AsyncWebCrawler, CrawlerRunConfig, CrawlResult +from ..models import CrawlResultContainer class DeepCrawlDecorator: """Decorator that adds deep crawling capability to arun method.""" @@ -14,32 +16,37 @@ class DeepCrawlDecorator: def __init__(self, crawler: AsyncWebCrawler): self.crawler = crawler - def __call__(self, original_arun): + def __call__(self, original_arun: Callable[Concatenate[str, Optional[CrawlerRunConfig], ...], Awaitable[CrawlResultContainer]]): @wraps(original_arun) - async def wrapped_arun(url: str, config: CrawlerRunConfig = None, **kwargs): + async def wrapped_arun( + url: str, config: Optional[CrawlerRunConfig] = None, **kwargs + ) -> CrawlResultContainer: # If deep crawling is already active, call the original method to avoid recursion. if config and config.deep_crawl_strategy and not self.deep_crawl_active.get(): token = self.deep_crawl_active.set(True) # Await the arun call to get the actual result object. - result_obj = await config.deep_crawl_strategy.arun( + result_obj: CrawlResultContainer = await config.deep_crawl_strategy.arun( crawler=self.crawler, start_url=url, config=config ) if config.stream: - async def result_wrapper(): + # Streaming mode. + async def result_wrapper() -> AsyncGenerator[CrawlResult, Any]: try: async for result in result_obj: yield result finally: self.deep_crawl_active.reset(token) - return result_wrapper() - else: - try: - return result_obj - finally: - self.deep_crawl_active.reset(token) - return await original_arun(url, config=config, **kwargs) + return CrawlResultContainer(result_wrapper()) + + # Batch mode. + try: + return result_obj + finally: + self.deep_crawl_active.reset(token) + + return await original_arun(url, config, **kwargs) return wrapped_arun class DeepCrawlStrategy(ABC): @@ -66,8 +73,11 @@ async def _arun_batch( """ pass + # Not async because it returns an AsyncGenerator but does not yield. + # See the following issue from pyright for more information: + # https://github.com/microsoft/pyright/issues/9949 @abstractmethod - async def _arun_stream( + def _arun_stream( self, start_url: str, crawler: AsyncWebCrawler, @@ -84,7 +94,7 @@ async def arun( start_url: str, crawler: AsyncWebCrawler, config: Optional[CrawlerRunConfig] = None, - ) -> RunManyReturn: + ) -> CrawlResultContainer: """ Traverse the given URL using the specified crawler. @@ -94,15 +104,15 @@ async def arun( crawler_run_config (Optional[CrawlerRunConfig]): Crawler configuration. 
Returns: - Union[CrawlResultT, List[CrawlResultT], AsyncGenerator[CrawlResultT, None]] + CrawlResultContainer """ if config is None: raise ValueError("CrawlerRunConfig must be provided") if config.stream: - return self._arun_stream(start_url, crawler, config) - else: - return await self._arun_batch(start_url, crawler, config) + return CrawlResultContainer(self._arun_stream(start_url, crawler, config)) + + return CrawlResultContainer(await self._arun_batch(start_url, crawler, config)) def __call__(self, start_url: str, crawler: AsyncWebCrawler, config: CrawlerRunConfig): return self.arun(start_url, crawler, config) diff --git a/crawl4ai/deep_crawling/bff_strategy.py b/crawl4ai/deep_crawling/bff_strategy.py index 4811ba141..c46c88e30 100644 --- a/crawl4ai/deep_crawling/bff_strategy.py +++ b/crawl4ai/deep_crawling/bff_strategy.py @@ -5,14 +5,12 @@ from typing import AsyncGenerator, Optional, Set, Dict, List, Tuple from urllib.parse import urlparse -from ..models import TraversalStats +from ..models import TraversalStats, CrawlResultContainer from .filters import FilterChain from .scorers import URLScorer from . import DeepCrawlStrategy -from ..types import AsyncWebCrawler, CrawlerRunConfig, CrawlResult, RunManyReturn - -from math import inf as infinity +from ..types import AsyncWebCrawler, CrawlerRunConfig, CrawlResult # Configurable batch size for processing items from the priority queue BATCH_SIZE = 10 @@ -38,7 +36,7 @@ def __init__( filter_chain: FilterChain = FilterChain(), url_scorer: Optional[URLScorer] = None, include_external: bool = False, - max_pages: int = infinity, + max_pages: int = -1, logger: Optional[logging.Logger] = None, ): self.max_depth = max_depth @@ -62,8 +60,6 @@ async def can_process_url(self, url: str, depth: int) -> bool: raise ValueError("Missing scheme or netloc") if parsed.scheme not in ("http", "https"): raise ValueError("Invalid scheme") - if "." not in parsed.netloc: - raise ValueError("Invalid domain") except Exception as e: self.logger.warning(f"Invalid URL: {url}, error: {e}") return False @@ -92,10 +88,12 @@ async def link_discovery( return # If we've reached the max pages limit, don't discover new links - remaining_capacity = self.max_pages - self._pages_crawled - if remaining_capacity <= 0: - self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping link discovery") - return + remaining_capacity: int = -1 + if self.max_pages > 0: + remaining_capacity = self.max_pages - self._pages_crawled + if remaining_capacity <= 0: + self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping link discovery") + return # Retrieve internal links; include external links if enabled. 
links = result.links.get("internal", []) @@ -106,7 +104,7 @@ async def link_discovery( valid_links = [] for link in links: url = link.get("href") - if url in visited: + if not url or url in visited: continue if not await self.can_process_url(url, new_depth): self.stats.urls_skipped += 1 @@ -115,7 +113,7 @@ async def link_discovery( valid_links.append(url) # If we have more valid links than capacity, limit them - if len(valid_links) > remaining_capacity: + if self.max_pages > 0 and len(valid_links) > remaining_capacity: valid_links = valid_links[:remaining_capacity] self.logger.info(f"Limiting to {remaining_capacity} URLs due to max_pages limit") @@ -144,7 +142,7 @@ async def _arun_best_first( while not queue.empty() and not self._cancel_event.is_set(): # Stop if we've reached the max pages limit - if self._pages_crawled >= self.max_pages: + if self.max_pages > 0 and self._pages_crawled >= self.max_pages: self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping crawl") break @@ -233,7 +231,7 @@ async def arun( start_url: str, crawler: AsyncWebCrawler, config: Optional[CrawlerRunConfig] = None, - ) -> "RunManyReturn": + ) -> CrawlResultContainer: """ Main entry point for best-first crawling. @@ -243,9 +241,9 @@ async def arun( if config is None: raise ValueError("CrawlerRunConfig must be provided") if config.stream: - return self._arun_stream(start_url, crawler, config) - else: - return await self._arun_batch(start_url, crawler, config) + return CrawlResultContainer(self._arun_stream(start_url, crawler, config)) + + return CrawlResultContainer(await self._arun_batch(start_url, crawler, config)) async def shutdown(self) -> None: """ diff --git a/crawl4ai/deep_crawling/bfs_strategy.py b/crawl4ai/deep_crawling/bfs_strategy.py index 54b72ea34..02bf1a91f 100644 --- a/crawl4ai/deep_crawling/bfs_strategy.py +++ b/crawl4ai/deep_crawling/bfs_strategy.py @@ -8,9 +8,9 @@ from ..models import TraversalStats from .filters import FilterChain from .scorers import URLScorer -from . import DeepCrawlStrategy +from . import DeepCrawlStrategy from ..types import AsyncWebCrawler, CrawlerRunConfig, CrawlResult -from ..utils import normalize_url_for_deep_crawl, efficient_normalize_url_for_deep_crawl +from ..utils import normalize_url_for_deep_crawl from math import inf as infinity class BFSDeepCrawlStrategy(DeepCrawlStrategy): @@ -29,7 +29,7 @@ def __init__( url_scorer: Optional[URLScorer] = None, include_external: bool = False, score_threshold: float = -infinity, - max_pages: int = infinity, + max_pages: int = -1, logger: Optional[logging.Logger] = None, ): self.max_depth = max_depth @@ -54,8 +54,6 @@ async def can_process_url(self, url: str, depth: int) -> bool: raise ValueError("Missing scheme or netloc") if parsed.scheme not in ("http", "https"): raise ValueError("Invalid scheme") - if "." 
not in parsed.netloc: - raise ValueError("Invalid domain") except Exception as e: self.logger.warning(f"Invalid URL: {url}, error: {e}") return False @@ -85,10 +83,13 @@ async def link_discovery( return # If we've reached the max pages limit, don't discover new links - remaining_capacity = self.max_pages - self._pages_crawled - if remaining_capacity <= 0: - self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping link discovery") - return + remaining_capacity: int = -1 + if self.max_pages > 0: + remaining_capacity = self.max_pages - self._pages_crawled + if remaining_capacity <= 0: + self.logger.info(f"Max pages limit ({self.max_pages}) reached, stopping link discovery") + return + # Get internal links and, if enabled, external links. links = result.links.get("internal", []) @@ -99,7 +100,10 @@ async def link_discovery( # First collect all valid links for link in links: - url = link.get("href") + url: Optional[str] = link.get("href") + if not url: + continue + # Strip URL fragments to avoid duplicate crawling # base_url = url.split('#')[0] if url else url base_url = normalize_url_for_deep_crawl(url, source_url) @@ -121,7 +125,7 @@ async def link_discovery( valid_links.append((base_url, score)) # If we have more valid links than capacity, sort by score and take the top ones - if len(valid_links) > remaining_capacity: + if self.max_pages > 0 and len(valid_links) > remaining_capacity: if self.url_scorer: # Sort by score in descending order valid_links.sort(key=lambda x: x[1], reverse=True) @@ -163,11 +167,7 @@ async def _arun_batch( # Clone the config to disable deep crawling recursion and enforce batch mode. batch_config = config.clone(deep_crawl_strategy=None, stream=False) batch_results = await crawler.arun_many(urls=urls, config=batch_config) - - # Update pages crawled counter - count only successful crawls - successful_results = [r for r in batch_results if r.success] - self._pages_crawled += len(successful_results) - + for result in batch_results: url = result.url depth = depths.get(url, 0) @@ -179,6 +179,7 @@ async def _arun_batch( # Only discover links from successful crawls if result.success: + self._pages_crawled += 1 # Link discovery will handle the max pages limit internally await self.link_discovery(result, url, depth, visited, next_level, depths) diff --git a/crawl4ai/deep_crawling/filters.py b/crawl4ai/deep_crawling/filters.py index 122be4829..e2385e8d8 100644 --- a/crawl4ai/deep_crawling/filters.py +++ b/crawl4ai/deep_crawling/filters.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Pattern, Set, Union +from typing import List, Optional, Pattern, Set, Union from urllib.parse import urlparse from array import array import re @@ -42,7 +42,7 @@ class URLFilter(ABC): __slots__ = ("name", "stats", "_logger_ref") - def __init__(self, name: str = None): + def __init__(self, name: Optional[str] = None): self.name = name or self.__class__.__name__ self.stats = FilterStats() # Lazy logger initialization using weakref @@ -71,8 +71,8 @@ class FilterChain: __slots__ = ("filters", "stats", "_logger_ref") - def __init__(self, filters: List[URLFilter] = None): - self.filters = tuple(filters or []) # Immutable tuple for speed + def __init__(self, filters: Optional[List[URLFilter]] = None): + self.filters: List[URLFilter] = filters if filters is not None else [] self.stats = FilterStats() self._logger_ref = None @@ -100,6 +100,11 @@ async def apply(self, url: str) -> bool: tasks.append(result) # Collect async tasks elif not result: # Sync rejection 
self.stats._counters[2] += 1 # Sync rejected + # Cancel remaining tasks + for idx, task in enumerate(tasks): + tasks[idx] = asyncio.create_task(task) + tasks[idx].cancel() + await asyncio.gather(*tasks, return_exceptions=True) return False if tasks: @@ -411,8 +416,8 @@ class DomainFilter(URLFilter): def __init__( self, - allowed_domains: Union[str, List[str]] = None, - blocked_domains: Union[str, List[str]] = None, + allowed_domains: Optional[Union[str, List[str]]] = None, + blocked_domains: Optional[Union[str, List[str]]] = None, ): super().__init__() @@ -572,8 +577,8 @@ class SEOFilter(URLFilter): def __init__( self, threshold: float = 0.65, - keywords: List[str] = None, - weights: Dict[str, float] = None, + keywords: Optional[List[str]] = None, + weights: Optional[Dict[str, float]] = None, ): super().__init__(name="SEOFilter") self.threshold = threshold diff --git a/crawl4ai/deep_crawling/scorers.py b/crawl4ai/deep_crawling/scorers.py index 1cd9f3e13..69413eb6e 100644 --- a/crawl4ai/deep_crawling/scorers.py +++ b/crawl4ai/deep_crawling/scorers.py @@ -1,13 +1,11 @@ from abc import ABC, abstractmethod from typing import List, Dict, Optional -from dataclasses import dataclass -from urllib.parse import urlparse, unquote import re -import logging from functools import lru_cache from array import array import ctypes import platform + PLATFORM = platform.system() # Pre-computed scores for common year differences @@ -25,13 +23,26 @@ class ScoringStats: __slots__ = ('_urls_scored', '_total_score', '_min_score', '_max_score') - - def __init__(self): - self._urls_scored = 0 - self._total_score = 0.0 - self._min_score = None # Lazy initialization - self._max_score = None - + + def __init__( + self, + urls_scored: int = 0, + total_score: float = 0.0, + min_score: Optional[float] = None, + max_score: Optional[float] = None, + ): + """Initialize scoring statistics. 
+ Args: + urls_scored (int): Number of URLs scored + total_score (float): Sum of all scores + min_score (float or None): Minimum score observed + max_score (float or None): Maximum score observed + """ + self._urls_scored = urls_scored + self._total_score = total_score + self._min_score = min_score + self._max_score = max_score + def update(self, score: float) -> None: """Optimized update with minimal operations""" self._urls_scored += 1 @@ -62,12 +73,12 @@ def get_max(self) -> float: return self._max_score class URLScorer(ABC): __slots__ = ('_weight', '_stats') - - def __init__(self, weight: float = 1.0): + + def __init__(self, weight: float = 1.0, stats: Optional[ScoringStats] = None): # Store weight directly as float32 for memory efficiency self._weight = ctypes.c_float(weight).value - self._stats = ScoringStats() - + self._stats = stats or ScoringStats() + @abstractmethod def _calculate_score(self, url: str) -> float: """Calculate raw score for URL.""" @@ -159,9 +170,9 @@ def score(self, url: str) -> float: class KeywordRelevanceScorer(URLScorer): __slots__ = ('_weight', '_stats', '_keywords', '_case_sensitive') - - def __init__(self, keywords: List[str], weight: float = 1.0, case_sensitive: bool = False): - super().__init__(weight=weight) + + def __init__(self, keywords: List[str], weight: float = 1.0, case_sensitive: bool = False, stats: Optional[ScoringStats] = None): + super().__init__(weight=weight, stats=stats) self._case_sensitive = case_sensitive # Pre-process keywords once self._keywords = [k if case_sensitive else k.lower() for k in keywords] diff --git a/crawl4ai/docker_client.py b/crawl4ai/docker_client.py index f4816eb5b..c0c74920d 100644 --- a/crawl4ai/docker_client.py +++ b/crawl4ai/docker_client.py @@ -1,11 +1,10 @@ -from typing import List, Optional, Union, AsyncGenerator, Dict, Any +from typing import List, Optional, AsyncGenerator, Dict, Any, Self import httpx import json -from urllib.parse import urljoin import asyncio from .async_configs import BrowserConfig, CrawlerRunConfig -from .models import CrawlResult +from .models import CrawlResult, CrawlResultContainer from .async_logger import AsyncLogger, LogLevel @@ -33,12 +32,14 @@ def __init__( timeout: float = 30.0, verify_ssl: bool = True, verbose: bool = True, - log_file: Optional[str] = None + log_file: Optional[str] = None, + transport: Optional[httpx.AsyncBaseTransport] = None, ): - self.base_url = base_url.rstrip('/') self.timeout = timeout self.logger = AsyncLogger(log_file=log_file, log_level=LogLevel.DEBUG, verbose=verbose) self._http_client = httpx.AsyncClient( + base_url=base_url.rstrip("/"), + transport=transport, timeout=timeout, verify=verify_ssl, headers={"Content-Type": "application/json"} @@ -47,10 +48,9 @@ def __init__( async def authenticate(self, email: str) -> None: """Authenticate with the server and store the token.""" - url = urljoin(self.base_url, "/token") try: self.logger.info(f"Authenticating with email: {email}", tag="AUTH") - response = await self._http_client.post(url, json={"email": email}) + response = await self._http_client.post("/token", json={"email": email}) response.raise_for_status() data = response.json() self._token = data["access_token"] @@ -64,26 +64,31 @@ async def authenticate(self, email: str) -> None: async def _check_server(self) -> None: """Check if server is reachable, raising an error if not.""" try: - await self._http_client.get(urljoin(self.base_url, "/health")) - self.logger.success(f"Connected to {self.base_url}", tag="READY") + await 
self._http_client.get("/health") + self.logger.success( + f"Connected to {self._http_client.base_url}", tag="READY" + ) except httpx.RequestError as e: self.logger.error(f"Server unreachable: {str(e)}", tag="ERROR") raise ConnectionError(f"Cannot connect to server: {str(e)}") - def _prepare_request(self, urls: List[str], browser_config: Optional[BrowserConfig] = None, - crawler_config: Optional[CrawlerRunConfig] = None) -> Dict[str, Any]: + def _prepare_request( + self, + urls: List[str], + browser_config: Optional[BrowserConfig] = None, + crawler_config: Optional[CrawlerRunConfig] = None, + ) -> Dict[str, Any]: """Prepare request data from configs.""" return { "urls": urls, "browser_config": browser_config.dump() if browser_config else {}, - "crawler_config": crawler_config.dump() if crawler_config else {} + "crawler_config": crawler_config.dump() if crawler_config else {}, } async def _request(self, method: str, endpoint: str, **kwargs) -> httpx.Response: """Make an HTTP request with error handling.""" - url = urljoin(self.base_url, endpoint) try: - response = await self._http_client.request(method, url, **kwargs) + response = await self._http_client.request(method, endpoint, **kwargs) response.raise_for_status() return response except httpx.TimeoutException as e: @@ -100,8 +105,8 @@ async def crawl( self, urls: List[str], browser_config: Optional[BrowserConfig] = None, - crawler_config: Optional[CrawlerRunConfig] = None - ) -> Union[CrawlResult, List[CrawlResult], AsyncGenerator[CrawlResult, None]]: + crawler_config: Optional[CrawlerRunConfig] = None, + ) -> CrawlResultContainer: """Execute a crawl operation.""" if not self._token: raise Crawl4aiClientError("Authentication required. Call authenticate() first.") @@ -114,8 +119,20 @@ async def crawl( if is_streaming: async def stream_results() -> AsyncGenerator[CrawlResult, None]: - async with self._http_client.stream("POST", f"{self.base_url}/crawl/stream", json=data) as response: - response.raise_for_status() + async with self._http_client.stream("POST", "/crawl/stream", json=data) as response: + if response.status_code != httpx.codes.OK: + await response.aread() + response_data: dict[str, Any] = response.json() + yield CrawlResult( + url=data.get("url", "unknown"), + html="", + success=False, + error_message=str( + response_data.get("detail", "Unknown error") + ), + ) + return + async for line in response.aiter_lines(): if line.strip(): result = json.loads(line) @@ -127,8 +144,9 @@ async def stream_results() -> AsyncGenerator[CrawlResult, None]: continue else: yield CrawlResult(**result) - return stream_results() - + + return CrawlResultContainer(stream_results()) + response = await self._request("POST", "/crawl", json=data) result_data = response.json() if not result_data.get("success", False): @@ -136,7 +154,7 @@ async def stream_results() -> AsyncGenerator[CrawlResult, None]: results = [CrawlResult(**r) for r in result_data.get("results", [])] self.logger.success(f"Crawl completed with {len(results)} results", tag="CRAWL") - return results[0] if len(results) == 1 else results + return CrawlResultContainer(results) async def get_schema(self) -> Dict[str, Any]: """Retrieve configuration schemas.""" @@ -150,7 +168,7 @@ async def close(self) -> None: self.logger.info("Closing client", tag="CLOSE") await self._http_client.aclose() - async def __aenter__(self) -> "Crawl4aiDockerClient": + async def __aenter__(self) -> Self: return self async def __aexit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[Any]) 
-> None: diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py index bf4825cc0..33b27c781 100644 --- a/crawl4ai/extraction_strategy.py +++ b/crawl4ai/extraction_strategy.py @@ -34,13 +34,14 @@ calculate_batch_size ) -from .types import LLMConfig, create_llm_config +from .types import LLMConfig from functools import partial import numpy as np import re from bs4 import BeautifulSoup -from lxml import html, etree +from lxml import etree +from lxml.html import fromstring class ExtractionStrategy(ABC): @@ -498,9 +499,9 @@ class LLMExtractionStrategy(ExtractionStrategy): } def __init__( self, - llm_config: 'LLMConfig' = None, - instruction: str = None, - schema: Dict = None, + llm_config: Optional[LLMConfig] = None, + instruction: Optional[str] = None, + schema: Optional[Dict] = None, extraction_type="block", chunk_token_threshold=CHUNK_TOKEN_THRESHOLD, overlap_rate=OVERLAP_RATE, @@ -512,8 +513,8 @@ def __init__( # Deprecated arguments provider: str = DEFAULT_PROVIDER, api_token: Optional[str] = None, - base_url: str = None, - api_base: str = None, + base_url: Optional[str] = None, + api_base: Optional[str] = None, **kwargs, ): """ @@ -884,7 +885,7 @@ def extract( return results @abstractmethod - def _parse_html(self, html_content: str): + def _parse_html(self, html_content: str) -> Any: """Parse HTML content into appropriate format""" pass @@ -1080,13 +1081,13 @@ def _get_element_attribute(self, element, attribute: str): @staticmethod def generate_schema( html: str, - schema_type: str = "CSS", # or XPATH - query: str = None, - target_json_example: str = None, - llm_config: 'LLMConfig' = create_llm_config(), - provider: str = None, - api_token: str = None, - **kwargs + schema_type: str = "CSS", # or XPATH + query: Optional[str] = None, + target_json_example: Optional[str] = None, + llm_config: Optional[LLMConfig] = None, + provider: Optional[str] = None, + api_token: Optional[str] = None, + **kwargs, ) -> dict: """ Generate extraction schema from HTML content and optional query. 
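For context on the new call path in this hunk: `generate_schema` no longer defaults `llm_config` to a pre-built config via `create_llm_config()`; callers either pass an `LLMConfig` explicitly or fall back to the deprecated `provider`/`api_token` arguments. A minimal usage sketch follows, assuming placeholder HTML, query, and provider/token values (these names are illustrative, not part of the patch):

```python
from crawl4ai import LLMConfig
from crawl4ai.extraction_strategy import JsonCssExtractionStrategy

# Sample markup and query, purely illustrative.
sample_html = "<div class='product'><h2 class='name'>Widget</h2><span class='price'>9.99</span></div>"

# Preferred path after this change: pass an explicit LLMConfig.
schema = JsonCssExtractionStrategy.generate_schema(
    html=sample_html,
    query="Extract the product name and price",
    llm_config=LLMConfig(provider="openai/gpt-4o-mini", api_token="env:OPENAI_API_KEY"),
)

# Deprecated path still works: provider/api_token are used when llm_config is None.
# schema = JsonCssExtractionStrategy.generate_schema(
#     html=sample_html,
#     provider="openai/gpt-4o-mini",
#     api_token="env:OPENAI_API_KEY",
# )

# The generated schema can then drive LLM-free extraction on subsequent crawls.
strategy = JsonCssExtractionStrategy(schema)
```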
@@ -1167,13 +1168,21 @@ def generate_schema( try: # Call LLM with backoff handling + base_url: str = "" + if llm_config: + provider = llm_config.provider + api_token = llm_config.api_token + base_url = llm_config.base_url or "" + response = perform_completion_with_backoff( - provider=llm_config.provider, - prompt_with_variables="\n\n".join([system_message["content"], user_message["content"]]), - json_response = True, - api_token=llm_config.api_token, - base_url=llm_config.base_url, - extra_args=kwargs + provider=provider, + prompt_with_variables="\n\n".join( + [system_message["content"], user_message["content"]] + ), + json_response=True, + api_token=api_token, + base_url=base_url, + extra_args=kwargs, ) # Extract and return schema @@ -1209,7 +1218,6 @@ def __init__(self, schema: Dict[str, Any], **kwargs): super().__init__(schema, **kwargs) def _parse_html(self, html_content: str): - # return BeautifulSoup(html_content, "html.parser") return BeautifulSoup(html_content, "lxml") def _get_base_elements(self, parsed_html, selector: str): @@ -1625,7 +1633,7 @@ def __init__(self, schema: Dict[str, Any], **kwargs): super().__init__(schema, **kwargs) def _parse_html(self, html_content: str): - return html.fromstring(html_content) + return fromstring(html_content) def _get_base_elements(self, parsed_html, selector: str): return parsed_html.xpath(selector) diff --git a/crawl4ai/hub.py b/crawl4ai/hub.py index 75056df77..e81aef680 100644 --- a/crawl4ai/hub.py +++ b/crawl4ai/hub.py @@ -1,6 +1,6 @@ # crawl4ai/hub.py from abc import ABC, abstractmethod -from typing import Dict, Type, Union +from typing import Dict, Type, Union, Any import logging import importlib from pathlib import Path @@ -10,9 +10,12 @@ class BaseCrawler(ABC): + meta: dict[str, Any] = {} + def __init__(self): self.logger = logging.getLogger(self.__class__.__name__) - + self._meta: dict[str, Any] = {} + @abstractmethod async def run(self, url: str = "", **kwargs) -> str: """ @@ -57,9 +60,16 @@ def _discover_crawlers(cls): @classmethod def _maybe_register_crawler(cls, obj, name: str): """Brilliant one-liner registration""" - if isinstance(obj, type) and issubclass(obj, BaseCrawler) and obj != BaseCrawler: - module = importlib.import_module(obj.__module__) - obj.meta = getattr(module, "__meta__", {}) + if ( + isinstance(obj, type) + and issubclass(obj, BaseCrawler) + and obj != BaseCrawler + ): + if hasattr(obj, "__meta__"): + obj.meta = obj.__meta__ # pyright: ignore[reportAttributeAccessIssue] + else: + module = importlib.import_module(obj.__module__) + obj.meta = getattr(module, "__meta__", {}) cls._crawlers[name] = obj @classmethod diff --git a/crawl4ai/legacy/crawler_strategy.py b/crawl4ai/legacy/crawler_strategy.py index 34e20ecd8..225d7cf19 100644 --- a/crawl4ai/legacy/crawler_strategy.py +++ b/crawl4ai/legacy/crawler_strategy.py @@ -10,8 +10,8 @@ # from webdriver_manager.chrome import ChromeDriverManager # from urllib3.exceptions import MaxRetryError -from .config import * -import logging, time +import logging +import time import base64 from PIL import Image, ImageDraw, ImageFont from io import BytesIO @@ -19,7 +19,7 @@ import requests import os from pathlib import Path -from .utils import * +from ..utils import sanitize_input_encode, wrap_text logger = logging.getLogger("selenium.webdriver.remote.remote_connection") logger.setLevel(logging.WARNING) @@ -45,7 +45,7 @@ def crawl(self, url: str, **kwargs) -> str: pass @abstractmethod - def take_screenshot(self, save_path: str): + def take_screenshot(self) -> str: pass 
@abstractmethod diff --git a/crawl4ai/legacy/database.py b/crawl4ai/legacy/database.py index 815b6b051..2fac7614c 100644 --- a/crawl4ai/legacy/database.py +++ b/crawl4ai/legacy/database.py @@ -24,7 +24,9 @@ def init_db(): media TEXT DEFAULT "{}", links TEXT DEFAULT "{}", metadata TEXT DEFAULT "{}", - screenshot TEXT DEFAULT "" + screenshot TEXT DEFAULT "", + response_headers TEXT DEFAULT "{}", -- Non-legacy field + downloaded_files TEXT DEFAULT "{}" -- Non-legacy field ) """ ) @@ -53,7 +55,7 @@ def check_db_path(): def get_cached_url( url: str, -) -> Optional[Tuple[str, str, str, str, str, str, str, bool, str]]: +) -> Optional[Tuple[str, str, str, str, str, bool, str, str, str, str]]: check_db_path() try: conn = sqlite3.connect(DB_PATH) diff --git a/crawl4ai/legacy/docs_manager.py b/crawl4ai/legacy/docs_manager.py index 9a6096a5e..37840d8d9 100644 --- a/crawl4ai/legacy/docs_manager.py +++ b/crawl4ai/legacy/docs_manager.py @@ -1,14 +1,19 @@ -import requests import shutil from pathlib import Path +from typing import Final + +import requests + from crawl4ai.async_logger import AsyncLogger -from crawl4ai.llmtxt import AsyncLLMTextManager +from .llmtxt import AsyncLLMTextManager + +GIT_DOCS: Final = "https://api.github.com/repos/unclecode/crawl4ai/contents/docs" class DocsManager: def __init__(self, logger=None): self.docs_dir = Path.home() / ".crawl4ai" / "docs" - self.local_docs = Path(__file__).parent.parent / "docs" / "llm.txt" + self.local_docs = Path(__file__).parent.parent.parent / "docs" self.docs_dir.mkdir(parents=True, exist_ok=True) self.logger = logger or AsyncLogger(verbose=True) self.llm_text = AsyncLLMTextManager(self.docs_dir, self.logger) @@ -21,40 +26,54 @@ async def ensure_docs_exist(self): async def fetch_docs(self) -> bool: """Copy from local docs or download from GitHub""" try: - # Try local first - if self.local_docs.exists() and ( - any(self.local_docs.glob("*.md")) - or any(self.local_docs.glob("*.tokens")) - ): - # Empty the local docs directory - for file_path in self.docs_dir.glob("*.md"): - file_path.unlink() - # for file_path in self.docs_dir.glob("*.tokens"): - # file_path.unlink() - for file_path in self.local_docs.glob("*.md"): - shutil.copy2(file_path, self.docs_dir / file_path.name) - # for file_path in self.local_docs.glob("*.tokens"): - # shutil.copy2(file_path, self.docs_dir / file_path.name) - return True - - # Fallback to GitHub - response = requests.get( - "https://api.github.com/repos/unclecode/crawl4ai/contents/docs/llm.txt", - headers={"Accept": "application/vnd.github.v3+json"}, - ) - response.raise_for_status() - - for item in response.json(): - if item["type"] == "file" and item["name"].endswith(".md"): - content = requests.get(item["download_url"]).text - with open(self.docs_dir / item["name"], "w", encoding="utf-8") as f: - f.write(content) - return True + # Remove existing markdown files. + dirs: set[Path] = set() + for file_path in self.docs_dir.glob("**/*.md"): + dirs.add(file_path.parent) + file_path.unlink() + + # Remove empty directories. + for dir_path in sorted(dirs, reverse=True): + if not any(dir_path.iterdir()): + dir_path.rmdir() + if self.local_docs.exists() and (any(self.local_docs.glob("**/*.md"))): + # Copy from local docs. + for file_path in self.local_docs.glob("**/*.md"): + rel_path = file_path.relative_to(self.local_docs) + dest_path = self.docs_dir / rel_path + dest_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(file_path, dest_path) + else: + # Download from GitHub. 
+ self.download_docs(GIT_DOCS) + + return True except Exception as e: self.logger.error(f"Failed to fetch docs: {str(e)}") raise + def download_docs(self, url: str): + """Download docs from GitHub""" + + response = requests.get( + url, + headers={"Accept": "application/vnd.github.v3+json"}, + ) + response.raise_for_status() + + for item in response.json(): + if item["type"] == "dir": + self.download_docs(item["url"]) + elif item["type"] == "file" and item["name"].endswith(".md"): + path: str = item["path"] + dest_path: Path = self.docs_dir / path.removeprefix("docs/") + dest_path.parent.mkdir(parents=True, exist_ok=True) + + content = requests.get(item["download_url"]).text + with open(dest_path, "w", encoding="utf-8") as f: + f.write(content) + def list(self) -> list[str]: """List available topics""" names = [file_path.stem for file_path in self.docs_dir.glob("*.md")] diff --git a/crawl4ai/legacy/llmtxt.py b/crawl4ai/legacy/llmtxt.py index 302564165..fc7b8883d 100644 --- a/crawl4ai/legacy/llmtxt.py +++ b/crawl4ai/legacy/llmtxt.py @@ -12,7 +12,7 @@ from nltk.corpus import stopwords from nltk.stem import WordNetLemmatizer from litellm import batch_completion -from .async_logger import AsyncLogger +from ..async_logger import AsyncLogger import litellm import pickle import hashlib # <--- ADDED for file-hash @@ -39,7 +39,7 @@ def __init__( batch_size: int = 3, ) -> None: self.docs_dir = docs_dir - self.logger = logger + self.logger = logger or AsyncLogger() self.max_concurrent_calls = max_concurrent_calls self.batch_size = batch_size self.bm25_index = None diff --git a/crawl4ai/legacy/web_crawler.py b/crawl4ai/legacy/web_crawler.py index a92ae6ddb..f18e91f9e 100644 --- a/crawl4ai/legacy/web_crawler.py +++ b/crawl4ai/legacy/web_crawler.py @@ -1,18 +1,19 @@ -import os, time +import os +import time os.environ["TOKENIZERS_PARALLELISM"] = "false" from pathlib import Path -from .models import UrlModel, CrawlResult +from ..models import UrlModel, CrawlResult from .database import init_db, get_cached_url, cache_url -from .utils import * -from .chunking_strategy import * -from .extraction_strategy import * -from .crawler_strategy import * -from typing import List +from ..utils import InvalidCSSSelectorError, format_html, sanitize_input_encode +from ..chunking_strategy import ChunkingStrategy, RegexChunking +from ..extraction_strategy import ExtractionStrategy, NoExtractionStrategy +from .crawler_strategy import CrawlerStrategy, LocalSeleniumCrawlerStrategy +from typing import List, Optional from concurrent.futures import ThreadPoolExecutor -from .content_scraping_strategy import WebScrapingStrategy -from .config import * +from ..content_scraping_strategy import WebScrapingStrategy +from ..config import DEFAULT_PROVIDER, IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, MIN_WORD_THRESHOLD import warnings import json @@ -25,7 +26,7 @@ class WebCrawler: def __init__( self, - crawler_strategy: CrawlerStrategy = None, + crawler_strategy: Optional[CrawlerStrategy] = None, always_by_pass_cache: bool = False, verbose: bool = False, ): @@ -57,13 +58,13 @@ def fetch_page( self, url_model: UrlModel, provider: str = DEFAULT_PROVIDER, - api_token: str = None, + api_token: Optional[str] = None, extract_blocks_flag: bool = True, word_count_threshold=MIN_WORD_THRESHOLD, - css_selector: str = None, + css_selector: Optional[str] = None, screenshot: bool = False, use_cached_html: bool = False, - extraction_strategy: ExtractionStrategy = None, + extraction_strategy: Optional[ExtractionStrategy] = None, chunking_strategy: 
ChunkingStrategy = RegexChunking(), **kwargs, ) -> CrawlResult: @@ -83,13 +84,13 @@ def fetch_pages( self, url_models: List[UrlModel], provider: str = DEFAULT_PROVIDER, - api_token: str = None, + api_token: Optional[str] = None, extract_blocks_flag: bool = True, word_count_threshold=MIN_WORD_THRESHOLD, use_cached_html: bool = False, - css_selector: str = None, + css_selector: Optional[str] = None, screenshot: bool = False, - extraction_strategy: ExtractionStrategy = None, + extraction_strategy: Optional[ExtractionStrategy] = None, chunking_strategy: ChunkingStrategy = RegexChunking(), **kwargs, ) -> List[CrawlResult]: @@ -122,34 +123,38 @@ def run( self, url: str, word_count_threshold=MIN_WORD_THRESHOLD, - extraction_strategy: ExtractionStrategy = None, + extraction_strategy: Optional[ExtractionStrategy] = None, chunking_strategy: ChunkingStrategy = RegexChunking(), bypass_cache: bool = False, - css_selector: str = None, + css_selector: Optional[str] = None, screenshot: bool = False, - user_agent: str = None, + user_agent: Optional[str] = None, verbose=True, **kwargs, ) -> CrawlResult: try: extraction_strategy = extraction_strategy or NoExtractionStrategy() - extraction_strategy.verbose = verbose if not isinstance(extraction_strategy, ExtractionStrategy): raise ValueError("Unsupported extraction strategy") + extraction_strategy.verbose = verbose + if not isinstance(chunking_strategy, ChunkingStrategy): raise ValueError("Unsupported chunking strategy") word_count_threshold = max(word_count_threshold, MIN_WORD_THRESHOLD) cached = None - screenshot_data = None - extracted_content = None + screenshot_data: str = "" + extracted_content: str = "" if not bypass_cache and not self.always_by_pass_cache: cached = get_cached_url(url) if kwargs.get("warmup", True) and not self.ready: - return None + raise ValueError( + "WebCrawler is not ready. Please call the warmup method before crawling." + ) + html: str = "" if cached: html = sanitize_input_encode(cached[1]) extracted_content = sanitize_input_encode(cached[4]) @@ -200,8 +205,8 @@ def process_html( word_count_threshold: int, extraction_strategy: ExtractionStrategy, chunking_strategy: ChunkingStrategy, - css_selector: str, - screenshot: bool, + css_selector: Optional[str], + screenshot: str, verbose: bool, is_cached: bool, **kwargs, @@ -240,42 +245,19 @@ def process_html( except InvalidCSSSelectorError as e: raise ValueError(str(e)) - cleaned_html = sanitize_input_encode(result.get("cleaned_html", "")) - markdown = sanitize_input_encode(result.get("markdown", "")) - media = result.get("media", []) - links = result.get("links", []) - metadata = result.get("metadata", {}) - - if extracted_content is None: - if verbose: - print( - f"[LOG] 🔥 Extracting semantic blocks for {url}, Strategy: {extraction_strategy.name}" - ) - - sections = chunking_strategy.chunk(markdown) - extracted_content = extraction_strategy.run(url, sections) - extracted_content = json.dumps( - extracted_content, indent=4, default=str, ensure_ascii=False - ) - - if verbose: - print( - f"[LOG] 🚀 Extraction done for {url}, time taken: {time.time() - t:.2f} seconds." 
- ) - - screenshot = None if not screenshot else screenshot + cleaned_html = sanitize_input_encode(result.cleaned_html) if not is_cached: cache_url( url, html, cleaned_html, - markdown, + "", extracted_content, True, - json.dumps(media), - json.dumps(links), - json.dumps(metadata), + json.dumps(result.media.model_dump()), + json.dumps(result.links.model_dump()), + json.dumps(result.metadata), screenshot=screenshot, ) @@ -283,10 +265,9 @@ def process_html( url=url, html=html, cleaned_html=format_html(cleaned_html), - markdown=markdown, - media=media, - links=links, - metadata=metadata, + media=result.media.model_dump(), + links=result.links.model_dump(), + metadata=result.metadata, screenshot=screenshot, extracted_content=extracted_content, success=True, diff --git a/crawl4ai/models.py b/crawl4ai/models.py index f9551c1ae..77454632d 100644 --- a/crawl4ai/models.py +++ b/crawl4ai/models.py @@ -1,5 +1,6 @@ -from pydantic import BaseModel, HttpUrl, PrivateAttr -from typing import List, Dict, Optional, Callable, Awaitable, Union, Any +from __future__ import annotations +from pydantic import BaseModel, HttpUrl, PrivateAttr, Field +from typing import List, Dict, Optional, Callable, Awaitable, Union, Any, AsyncGenerator, Iterator, AsyncIterator from enum import Enum from dataclasses import dataclass from .ssl_certificate import SSLCertificate @@ -21,7 +22,7 @@ class DomainState: class CrawlerTaskResult: task_id: str url: str - result: "CrawlResult" + result: CrawlResultContainer memory_usage: float peak_memory: float start_time: Union[datetime, float] @@ -42,26 +43,6 @@ class CrawlStatus(Enum): FAILED = "FAILED" -# @dataclass -# class CrawlStats: -# task_id: str -# url: str -# status: CrawlStatus -# start_time: Optional[datetime] = None -# end_time: Optional[datetime] = None -# memory_usage: float = 0.0 -# peak_memory: float = 0.0 -# error_message: str = "" - -# @property -# def duration(self) -> str: -# if not self.start_time: -# return "0:00" -# end = self.end_time or datetime.now() -# duration = end - self.start_time -# return str(timedelta(seconds=int(duration.total_seconds()))) - - @dataclass class CrawlStats: task_id: str @@ -91,9 +72,9 @@ def duration(self) -> str: # Convert end_time to datetime if it's a float if isinstance(end, float): end = datetime.fromtimestamp(end) - - duration = end - start - return str(timedelta(seconds=int(duration.total_seconds()))) + + duration = end - start # pyright: ignore[reportOperatorIssue] + return str(timedelta(seconds=int(duration.total_seconds()))) # pyright: ignore[reportAttributeAccessIssue] class DisplayMode(Enum): @@ -119,9 +100,9 @@ class UrlModel(BaseModel): class MarkdownGenerationResult(BaseModel): - raw_markdown: str - markdown_with_citations: str - references_markdown: str + raw_markdown: str = "" + markdown_with_citations: str = "" + references_markdown: str = "" fit_markdown: Optional[str] = None fit_html: Optional[str] = None @@ -152,8 +133,8 @@ class CrawlResult(BaseModel): html: str success: bool cleaned_html: Optional[str] = None - media: Dict[str, List[Dict]] = {} - links: Dict[str, List[Dict]] = {} + media: Dict[str, List[Dict]] = Field(default_factory=dict) + links: Dict[str, List[Dict]] = Field(default_factory=dict) downloaded_files: Optional[List[str]] = None js_execution_result: Optional[Dict[str, Any]] = None screenshot: Optional[str] = None @@ -172,26 +153,67 @@ class CrawlResult(BaseModel): class Config: arbitrary_types_allowed = True -# NOTE: The StringCompatibleMarkdown class, custom __init__ method, property getters/setters, 
-# and model_dump override all exist to support a smooth transition from markdown as a string -# to markdown as a MarkdownGenerationResult object, while maintaining backward compatibility. -# -# This allows code that expects markdown to be a string to continue working, while also -# providing access to the full MarkdownGenerationResult object's properties. -# -# The markdown_v2 property is deprecated and raises an error directing users to use markdown. -# -# When backward compatibility is no longer needed in future versions, this entire mechanism -# can be simplified to a standard field with no custom accessors or serialization logic. - - def __init__(self, **data): - markdown_result = data.pop('markdown', None) - super().__init__(**data) - if markdown_result is not None: + # NOTE: The StringCompatibleMarkdown class, custom __init__ method, property getters/setters, + # and model_dump override all exist to support a smooth transition from markdown as a string + # to markdown as a MarkdownGenerationResult object, while maintaining backward compatibility. + # + # This allows code that expects markdown to be a string to continue working, while also + # providing access to the full MarkdownGenerationResult object's properties. + # + # The markdown_v2 property is deprecated and raises an error directing users to use markdown. + # + # When backward compatibility is no longer needed in future versions, this entire mechanism + # can be simplified to a standard field with no custom accessors or serialization logic. + + def __init__( + self, + url: str, + html: str, + success: bool, + cleaned_html: Optional[str] = None, + media: Optional[Dict[str, List[Dict]]] = None, + links: Optional[Dict[str, List[Dict]]] = None, + downloaded_files: Optional[List[str]] = None, + js_execution_result: Optional[Dict[str, Any]] = None, + screenshot: Optional[str] = None, + pdf: Optional[bytes] = None, + markdown: Optional[Union[MarkdownGenerationResult, dict]] = None, + extracted_content: Optional[str] = None, + metadata: Optional[dict] = None, + error_message: Optional[str] = None, + session_id: Optional[str] = None, + response_headers: Optional[dict] = None, + status_code: Optional[int] = None, + ssl_certificate: Optional[SSLCertificate] = None, + dispatch_result: Optional[DispatchResult] = None, + redirected_url: Optional[str] = None + ): + super().__init__( + url=url, + html=html, + success=success, + cleaned_html=cleaned_html, + media=media if media is not None else {}, + links=links if links is not None else {}, + downloaded_files=downloaded_files, + js_execution_result=js_execution_result, + screenshot=screenshot, + pdf=pdf, + extracted_content=extracted_content, + metadata=metadata, + error_message=error_message, + session_id=session_id, + response_headers=response_headers, + status_code=status_code, + ssl_certificate=ssl_certificate, + dispatch_result=dispatch_result, + redirected_url=redirected_url + ) + if markdown is not None: self._markdown = ( - MarkdownGenerationResult(**markdown_result) - if isinstance(markdown_result, dict) - else markdown_result + MarkdownGenerationResult(**markdown) + if isinstance(markdown, dict) + else markdown ) @property @@ -273,16 +295,168 @@ def model_dump(self, *args, **kwargs): result["markdown"] = self._markdown.model_dump() return result +CrawlResultsT = Union[ + CrawlResult, List[CrawlResult], AsyncGenerator[CrawlResult, None] +] + +class CrawlResultContainer(CrawlResult): + """A container class for crawl results. 
+ + Provides a consistent interface for synchronous and asynchronous iteration + as well as direct access to fields of first result and the length of the + results. + """ + # We use private attributes and a property for source to simplify the + # implementation of __getattribute__. + _source: CrawlResultsT = PrivateAttr() + _results: List[CrawlResult] = PrivateAttr() + + def __init__( + self, + results: CrawlResultsT, + ) -> None: + result_list: List[CrawlResult] + if isinstance(results, AsyncGenerator): + result_list = [] + elif isinstance(results, List): + result_list = results + else: + result_list = [results] + + if len(result_list) == 0: + super().__init__(url="", html="", success=False) + else: + super().__init__(**result_list[0].model_dump()) + + self._source = results + self._results = result_list + + @property + def source(self) -> CrawlResultsT: + """Returns the source of the crawl results. + + :return: The source of the crawl results. + :rtype: CrawlResultsT + """ + return self._source + + def _raise_if_async_generator(self): + """Raises a TypeError if the source is an AsyncGenerator. + + This is to prevent synchronous operations over an asynchronous source. + + :raises TypeError: If source is an AsyncGenerator. + """ + if isinstance(self._source, AsyncGenerator): + raise TypeError( + "CrawlResultContainer source is an AsyncGenerator. Use __aiter__() to iterate over it." + ) + + def __iter__(self) -> Iterator[CrawlResult]: # pyright: ignore[reportIncompatibleMethodOverride] + """Returns an iterator for the crawl results. + + This method is used for synchronous iteration. + + :return: An iterator for the crawl results. + :rtype: Iterator[CrawlResult] + :raises TypeError: If the source is an AsyncGenerator. + """ + self._raise_if_async_generator() + + return iter(self._results) + + def __aiter__(self) -> AsyncIterator[CrawlResult]: + """Returns an asynchronous iterator for the crawl results.""" + if isinstance(self._source, AsyncIterator): + return self._source.__aiter__() + + async def async_iterator() -> AsyncIterator[CrawlResult]: + for result in self._results: + yield result + + return async_iterator() + + def __getitem__(self, index: int) -> CrawlResult: + """Return the result at a given index. + + :param index: The index of the result to retrieve. + :type index: int + :return: The crawl result at the specified index. + :rtype: CrawlResult + :raises TypeError: If the source is an AsyncGenerator. + :raises IndexError: If the index is out of range. + """ + self._raise_if_async_generator() + + return self._results[index] + + def __len__(self) -> int: + """Return the number of results in the container. + + :return: The number of results. + :rtype: int + :raises TypeError: If the source is an AsyncGenerator. + """ + self._raise_if_async_generator() + + return len(self._results) + + def __getattribute__(self, attr: str) -> Any: + """Return an attribute from the first result. + + :param attr: The name of the attribute to retrieve. + :type attr: str + :return: The attribute value from the first result if present. + :rtype: Any + :raises TypeError: If the source is an AsyncGenerator. + :raises AttributeError: If the attribute does not exist. + """ + if attr.startswith("_") or attr == "source": + # Private attribute or known local field so just delegate to the parent class. + return super().__getattribute__(attr) + + try: + source: CrawlResultsT = self._source + except (AttributeError, TypeError): + # _source is not defined yet so we're in the __init__ method. 
+ # Just delegate to the parent class. + return super().__getattribute__(attr) + + # We have a CrawlResult field. + # Local test to avoid the additional lookups from calling _raise_if_async_generator. + if isinstance(source, AsyncGenerator): + raise TypeError( + "CrawlResultContainer source is an AsyncGenerator. Use __aiter__() to iterate over it." + ) + + if not source: + # Empty source so we can't return the attribute. + raise AttributeError(f"{self.__class__.__name__} object has no results") + + # Delegate to the first result. + return super().__getattribute__(attr) + + def __repr__(self) -> str: + """Get a string representation of the container. + + The representation will be incomplete if the source is an AsyncIterator. + :return: String representation of the container. + :rtype: str + """ + + return f"{self.__class__.__name__}({self._results!r})" + class StringCompatibleMarkdown(str): """A string subclass that also provides access to MarkdownGenerationResult attributes""" def __new__(cls, markdown_result): return super().__new__(cls, markdown_result.raw_markdown) def __init__(self, markdown_result): - self._markdown_result = markdown_result - + self.markdown_result = markdown_result + def __getattr__(self, name): - return getattr(self._markdown_result, name) + return getattr(self.markdown_result, name) + # END of backward compatibility code for markdown/markdown_v2. # When removing this code in the future, make sure to: @@ -319,6 +493,12 @@ class MediaItem(BaseModel): format: Optional[str] = None width: Optional[int] = None + def __init__(self, **data): + if "width" in data and data["width"] == "undefined": + data["width"] = None + + super().__init__(**data) + class Link(BaseModel): href: Optional[str] = "" @@ -328,19 +508,19 @@ class Link(BaseModel): class Media(BaseModel): - images: List[MediaItem] = [] - videos: List[ - MediaItem - ] = [] # Using MediaItem model for now, can be extended with Video model if needed - audios: List[ - MediaItem - ] = [] # Using MediaItem model for now, can be extended with Audio model if needed - tables: List[Dict] = [] # Table data extracted from HTML tables + images: List[MediaItem] = Field(default_factory=list) + videos: List[MediaItem] = Field( + default_factory=list + ) # Using MediaItem model for now, can be extended with Video model if needed + audios: List[MediaItem] = Field( + default_factory=list + ) # Using MediaItem model for now, can be extended with Audio model if needed + tables: List[Dict] = Field(default_factory=list) # Table data extracted from HTML tables class Links(BaseModel): - internal: List[Link] = [] - external: List[Link] = [] + internal: List[Link] = Field(default_factory=list) + external: List[Link] = Field(default_factory=list) class ScrapingResult(BaseModel): @@ -348,4 +528,4 @@ class ScrapingResult(BaseModel): success: bool media: Media = Media() links: Links = Links() - metadata: Dict[str, Any] = {} + metadata: Dict[str, Any] = Field(default_factory=dict) diff --git a/crawl4ai/proxy_strategy.py b/crawl4ai/proxy_strategy.py index 6821c566f..1aa773fcc 100644 --- a/crawl4ai/proxy_strategy.py +++ b/crawl4ai/proxy_strategy.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Optional +from typing import List, Optional, Dict from abc import ABC, abstractmethod from itertools import cycle import os @@ -119,19 +119,19 @@ class ProxyRotationStrategy(ABC): """Base abstract class for proxy rotation strategies""" @abstractmethod - async def get_next_proxy(self) -> Optional[Dict]: + async def get_next_proxy(self) -> 
Optional[ProxyConfig]: """Get next proxy configuration from the strategy""" pass @abstractmethod - def add_proxies(self, proxies: List[Dict]): + def add_proxies(self, proxies: List[ProxyConfig]): """Add proxy configurations to the strategy""" pass class RoundRobinProxyStrategy: """Simple round-robin proxy rotation strategy using ProxyConfig objects""" - def __init__(self, proxies: List[ProxyConfig] = None): + def __init__(self, proxies: Optional[List[ProxyConfig]] = None): """ Initialize with optional list of proxy configurations diff --git a/crawl4ai/types.py b/crawl4ai/types.py index 63fd45bae..dc90a2a51 100644 --- a/crawl4ai/types.py +++ b/crawl4ai/types.py @@ -10,6 +10,7 @@ CrawlResult = Union['CrawlResultType'] CrawlerHub = Union['CrawlerHubType'] BrowserProfiler = Union['BrowserProfilerType'] +CrawlResultContainer = Union['CrawlResultContainerType'] # Configuration types BrowserConfig = Union['BrowserConfigType'] @@ -54,7 +55,6 @@ RateLimiter = Union['RateLimiterType'] CrawlerMonitor = Union['CrawlerMonitorType'] DisplayMode = Union['DisplayModeType'] -RunManyReturn = Union['RunManyReturnType'] # Docker client Crawl4aiDockerClient = Union['Crawl4aiDockerClientType'] @@ -91,7 +91,10 @@ AsyncWebCrawler as AsyncWebCrawlerType, CacheMode as CacheModeType, ) - from .models import CrawlResult as CrawlResultType + from .models import ( + CrawlResult as CrawlResultType, + CrawlResultContainer as CrawlResultContainerType, + ) from .hub import CrawlerHub as CrawlerHubType from .browser_profiler import BrowserProfiler as BrowserProfilerType @@ -153,7 +156,6 @@ RateLimiter as RateLimiterType, CrawlerMonitor as CrawlerMonitorType, DisplayMode as DisplayModeType, - RunManyReturn as RunManyReturnType, ) # Docker client @@ -179,9 +181,3 @@ DFSDeepCrawlStrategy as DFSDeepCrawlStrategyType, DeepCrawlDecorator as DeepCrawlDecoratorType, ) - - - -def create_llm_config(*args, **kwargs) -> 'LLMConfigType': - from .async_configs import LLMConfig - return LLMConfig(*args, **kwargs) diff --git a/crawl4ai/user_agent_generator.py b/crawl4ai/user_agent_generator.py index df2125680..c1df7e260 100644 --- a/crawl4ai/user_agent_generator.py +++ b/crawl4ai/user_agent_generator.py @@ -1,5 +1,5 @@ import random -from typing import Optional, Literal, List, Dict, Tuple +from typing import Optional, Literal, List, Dict, Tuple, Union import re from abc import ABC, abstractmethod diff --git a/crawl4ai/utils.py b/crawl4ai/utils.py index 02d105a94..a02e881f3 100644 --- a/crawl4ai/utils.py +++ b/crawl4ai/utils.py @@ -1,22 +1,22 @@ import time from concurrent.futures import ThreadPoolExecutor, as_completed -from bs4 import BeautifulSoup, Comment, element, Tag, NavigableString +from bs4 import BeautifulSoup +from bs4.element import Comment, Tag, NavigableString, PageElement import json import html -import lxml +from lxml.html import fromstring, tostring, document_fromstring import re import os import platform from .prompts import PROMPT_EXTRACT_BLOCKS from array import array from .html2text import html2text, CustomHTML2Text -# from .config import * from .config import MIN_WORD_THRESHOLD, IMAGE_DESCRIPTION_MIN_WORD_THRESHOLD, IMAGE_SCORE_THRESHOLD, DEFAULT_PROVIDER, PROVIDER_MODELS import httpx from socket import gaierror from pathlib import Path from typing import Dict, Any, List, Optional, Callable -from urllib.parse import urljoin +from urllib.parse import ParseResult, urljoin import requests from requests.exceptions import InvalidSchema import xxhash @@ -134,7 +134,7 @@ def merge_chunks( target_size: int, overlap: int = 0, 
word_token_ratio: float = 1.0, - splitter: Callable = None + splitter: Optional[Callable] = None, ) -> List[str]: """Merges documents into chunks of specified token size. @@ -509,8 +509,8 @@ def calculate_semaphore_count(): int: The calculated semaphore count. """ - cpu_count = os.cpu_count() - memory_gb = get_system_memory() / (1024**3) # Convert to GB + cpu_count = os.cpu_count() or 1 + memory_gb = get_system_memory() or 0 / (1024**3) # Convert to GB base_count = max(1, cpu_count // 2) memory_based_cap = int(memory_gb / 2) # Assume 2GB per instance return min(base_count, memory_based_cap) @@ -546,7 +546,7 @@ def get_system_memory(): elif system == "Windows": import ctypes - kernel32 = ctypes.windll.kernel32 + kernel32 = ctypes.windll.kernel32 # pyright: ignore[reportAttributeAccessIssue] c_ulonglong = ctypes.c_ulonglong class MEMORYSTATUSEX(ctypes.Structure): @@ -983,7 +983,7 @@ def replace_pre_tags_with_text(node): # Recursively remove empty elements, their parent elements, and elements with word count below threshold def remove_empty_and_low_word_count_elements(node, word_count_threshold): for child in node.contents: - if isinstance(child, element.Tag): + if isinstance(child, Tag): remove_empty_and_low_word_count_elements( child, word_count_threshold ) @@ -1051,7 +1051,7 @@ def remove_empty_tags(body: Tag): # Flatten nested elements with only one child of the same type def flatten_nested_elements(node): for child in node.contents: - if isinstance(child, element.Tag): + if isinstance(child, Tag): flatten_nested_elements(child) if ( len(child.contents) == 1 @@ -1108,9 +1108,9 @@ def get_content_of_website_optimized( url: str, html: str, word_count_threshold: int = MIN_WORD_THRESHOLD, - css_selector: str = None, + css_selector: Optional[str] = None, **kwargs, -) -> Dict[str, Any]: +) -> Optional[Dict[str, Any]]: if not html: return None @@ -1243,7 +1243,7 @@ def fetch_image_file_size(img, base_url): "type": "image", } - def process_element(element: element.PageElement) -> bool: + def process_element(element: PageElement) -> bool: try: if isinstance(element, NavigableString): if isinstance(element, Comment): @@ -1364,7 +1364,7 @@ def flatten_nested_elements(node): return node if ( len(node.contents) == 1 - and isinstance(node.contents[0], element.Tag) + and isinstance(node.contents[0], Tag) and node.contents[0].name == node.name ): return flatten_nested_elements(node.contents[0]) @@ -1416,7 +1416,7 @@ def extract_metadata_using_lxml(html, doc=None): if doc is None: try: - doc = lxml.html.document_fromstring(html) + doc = document_fromstring(html) except Exception: return {} @@ -1928,7 +1928,7 @@ def wrap_text(draw, text, font, max_width): return "\n".join(lines) -def format_html(html_string): +def format_html(html_string) -> str: """ Prettify an HTML string using BeautifulSoup. @@ -1944,8 +1944,8 @@ def format_html(html_string): str: The prettified HTML string. 
""" - soup = BeautifulSoup(html_string, "lxml.parser") - return soup.prettify() + soup = BeautifulSoup(html_string, "lxml") + return soup.prettify() # pyright: ignore[reportReturnType] def fast_format_html(html_string): @@ -2008,14 +2008,10 @@ def normalize_url(href, base_url): return normalized -def normalize_url_for_deep_crawl(href, base_url): +def normalize_url_for_deep_crawl(href: str, base_url: str) -> str: """Normalize URLs to ensure consistent format""" from urllib.parse import urljoin, urlparse, urlunparse, parse_qs, urlencode - # Handle None or empty values - if not href: - return None - # Use urljoin to handle relative URLs full_url = urljoin(base_url, href.strip()) @@ -2119,6 +2115,7 @@ def normalize_url_tmp(href, base_url): return href.strip() +DEFAULT_PORTS = {"http": 80, "https": 443} def get_base_domain(url: str) -> str: """ @@ -2126,7 +2123,7 @@ def get_base_domain(url: str) -> str: How it works: 1. Parses the URL to extract the domain. - 2. Removes the port number and 'www' prefix. + 2. Removes the port number and 'www' prefix if necessary. 3. Handles special domains (e.g., 'co.uk') to extract the correct base. Args: @@ -2136,8 +2133,8 @@ def get_base_domain(url: str) -> str: str: The extracted base domain or an empty string if parsing fails. """ try: - # Get domain from URL - domain = urlparse(url).netloc.lower() + parsed: ParseResult = urlparse(url) + domain = parsed.netloc.lower() if not domain: return "" @@ -2145,7 +2142,14 @@ def get_base_domain(url: str) -> str: domain = domain.split(":")[0] # Remove www - domain = re.sub(r"^www\.", "", domain) + domain = domain.removeprefix("www.") + + port_suffix: str = "" + port = parsed.port + if port is not None and port != DEFAULT_PORTS.get(parsed.scheme): + # Port needed. + port_suffix = f":{port}" + # Extract last two parts of domain (handles co.uk etc) parts = domain.split(".") @@ -2164,9 +2168,9 @@ def get_base_domain(url: str) -> str: "af", "ag", }: - return ".".join(parts[-3:]) + return ".".join(parts[-3:]) + port_suffix - return ".".join(parts[-2:]) + return ".".join(parts[-2:]) + port_suffix except Exception: return "" @@ -2524,7 +2528,7 @@ def configure_windows_event_loop(): ``` """ if platform.system() == "Windows": - asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) # pyright: ignore[reportAttributeAccessIssue] def get_error_context(exc_info, context_lines: int = 5): @@ -2588,9 +2592,9 @@ def truncate(value, threshold): return value[:threshold] + '...' 
# Add ellipsis to indicate truncation return value -def optimize_html(html_str, threshold=200): - root = lxml.html.fromstring(html_str) - +def optimize_html(html_str, threshold=200) -> str: + root = fromstring(html_str) + for _element in root.iter(): # Process attributes for attr in list(_element.attrib): @@ -2603,8 +2607,9 @@ def optimize_html(html_str, threshold=200): # Process tail text if _element.tail and len(_element.tail) > threshold: _element.tail = truncate(_element.tail, threshold) - - return lxml.html.tostring(root, encoding='unicode', pretty_print=False) + + return tostring(root, encoding="unicode", pretty_print=False) # pyright: ignore[reportReturnType] + class HeadPeekr: @staticmethod @@ -2659,6 +2664,7 @@ def extract_meta_tags(head_content: str): return meta_tags + @staticmethod def get_title(head_content: str): title_match = re.search(r'(.*?)', head_content, re.IGNORECASE | re.DOTALL) return title_match.group(1) if title_match else None diff --git a/deploy/docker/api.py b/deploy/docker/api.py index 338027721..7b23f116e 100644 --- a/deploy/docker/api.py +++ b/deploy/docker/api.py @@ -1,16 +1,17 @@ import os import json import asyncio -from typing import List, Tuple +from dataclasses import dataclass +from typing import List, Tuple, Optional, Annotated, AsyncGenerator, Any from functools import partial import logging -from typing import Optional, AsyncGenerator from urllib.parse import unquote -from fastapi import HTTPException, Request, status +from fastapi import HTTPException, Request, status, Body from fastapi.background import BackgroundTasks from fastapi.responses import JSONResponse from redis import asyncio as aioredis +from pydantic.fields import Field from crawl4ai import ( AsyncWebCrawler, @@ -42,11 +43,26 @@ logger = logging.getLogger(__name__) -async def handle_llm_qa( - url: str, - query: str, - config: dict -) -> str: + +@dataclass +class CrawlRequest: + urls: List[str] + browser_config: BrowserConfig + crawler_config: CrawlerRunConfig + + def __init__( + self, + urls: Annotated[List[str], Field(min_length=1, max_length=100)] = Body(None), + browser_config: Optional[dict[str, Any]] = Body(None), + crawler_config: Optional[dict[str, Any]] = Body(None), + ) -> None: + """Build a CrawlRequest object from request body.""" + self.urls = urls + self.browser_config = BrowserConfig.load(browser_config) + self.crawler_config = CrawlerRunConfig.load(crawler_config) + + +async def handle_llm_qa(url: str, query: str, config: dict) -> str: """Process QA using LLM with crawled content as context.""" try: # Extract base URL by finding last '?q=' occurrence @@ -151,9 +167,9 @@ async def process_llm_extraction( async def handle_markdown_request( url: str, filter_type: FilterType, + config: dict, query: Optional[str] = None, cache: str = "0", - config: Optional[dict] = None ) -> str: """Handle markdown generation requests.""" try: @@ -371,16 +387,11 @@ async def stream_results(crawler: AsyncWebCrawler, results_gen: AsyncGenerator) logger.error(f"Crawler cleanup error: {e}") async def handle_crawl_request( - urls: List[str], - browser_config: dict, - crawler_config: dict, - config: dict + crawl_request: CrawlRequest, + config: dict, ) -> dict: """Handle non-streaming crawl requests.""" try: - browser_config = BrowserConfig.load(browser_config) - crawler_config = CrawlerRunConfig.load(crawler_config) - dispatcher = MemoryAdaptiveDispatcher( memory_threshold_percent=config["crawler"]["memory_threshold_percent"], rate_limiter=RateLimiter( @@ -388,19 +399,20 @@ async def 
handle_crawl_request( ) ) - async with AsyncWebCrawler(config=browser_config) as crawler: + async with AsyncWebCrawler(config=crawl_request.browser_config) as crawler: results = [] - func = getattr(crawler, "arun" if len(urls) == 1 else "arun_many") - partial_func = partial(func, - urls[0] if len(urls) == 1 else urls, - config=crawler_config, - dispatcher=dispatcher) + func = getattr(crawler, "arun" if len(crawl_request.urls) == 1 else "arun_many") + partial_func = partial(func, + crawl_request.urls[0] if len(crawl_request.urls) == 1 else crawl_request.urls, + config=crawl_request.crawler_config, + dispatcher=dispatcher,) results = await partial_func() return { "success": True, "results": [result.model_dump() for result in results] } - + except HTTPException: + raise except Exception as e: logger.error(f"Crawl error: {str(e)}", exc_info=True) raise HTTPException( @@ -409,17 +421,14 @@ async def handle_crawl_request( ) async def handle_stream_crawl_request( - urls: List[str], - browser_config: dict, - crawler_config: dict, - config: dict + crawl_request: CrawlRequest, + config: dict, ) -> Tuple[AsyncWebCrawler, AsyncGenerator]: """Handle streaming crawl requests.""" + crawler: Optional[AsyncWebCrawler] = None try: - browser_config = BrowserConfig.load(browser_config) - browser_config.verbose = True - crawler_config = CrawlerRunConfig.load(crawler_config) - crawler_config.scraping_strategy = LXMLWebScrapingStrategy() + crawl_request.browser_config.verbose = True + crawl_request.crawler_config.scraping_strategy = LXMLWebScrapingStrategy() dispatcher = MemoryAdaptiveDispatcher( memory_threshold_percent=config["crawler"]["memory_threshold_percent"], @@ -428,22 +437,29 @@ async def handle_stream_crawl_request( ) ) - crawler = AsyncWebCrawler(config=browser_config) + crawler = AsyncWebCrawler(config=crawl_request.browser_config) await crawler.start() results_gen = await crawler.arun_many( - urls=urls, - config=crawler_config, - dispatcher=dispatcher + urls=crawl_request.urls, + config=crawl_request.crawler_config, + dispatcher=dispatcher, ) + if not isinstance(results_gen.source, AsyncGenerator): + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Unexpected results type {type(results_gen.source)} expected AsyncGenerator", + ) - return crawler, results_gen + return crawler, results_gen.source + except HTTPException: + raise except Exception as e: - if 'crawler' in locals(): - await crawler.close() logger.error(f"Stream crawl error: {str(e)}", exc_info=True) + if crawler is not None: + await crawler.close() raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) \ No newline at end of file + ) diff --git a/deploy/docker/server.py b/deploy/docker/server.py index edb551309..c20ef6a07 100644 --- a/deploy/docker/server.py +++ b/deploy/docker/server.py @@ -1,17 +1,18 @@ import os import sys import time -from typing import List, Optional, Dict -from fastapi import FastAPI, HTTPException, Request, Query, Path, Depends +from typing import Optional, Dict, Any, Annotated +from fastapi import FastAPI, HTTPException, Request, Query, Path, Depends, status from fastapi.responses import StreamingResponse, RedirectResponse, PlainTextResponse, JSONResponse from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware from fastapi.middleware.trustedhost import TrustedHostMiddleware -from pydantic import BaseModel, Field from slowapi import Limiter from slowapi.util import get_remote_address from prometheus_fastapi_instrumentator import 
Instrumentator from redis import asyncio as aioredis +from crawl4ai.async_configs import Serialisable + sys.path.append(os.path.dirname(os.path.realpath(__file__))) from utils import FilterType, load_config, setup_logging, verify_email_domain from api import ( @@ -19,16 +20,13 @@ handle_llm_qa, handle_stream_crawl_request, handle_crawl_request, - stream_results + stream_results, + CrawlRequest ) from auth import create_access_token, get_token_dependency, TokenRequest # Import from auth.py __version__ = "0.2.6" -class CrawlRequest(BaseModel): - urls: List[str] = Field(min_length=1, max_length=100) - browser_config: Optional[Dict] = Field(default_factory=dict) - crawler_config: Optional[Dict] = Field(default_factory=dict) # Load configuration and setup config = load_config() @@ -79,7 +77,9 @@ async def add_security_headers(request: Request, call_next): @app.post("/token") async def get_token(request_data: TokenRequest): if not verify_email_domain(request_data.email): - raise HTTPException(status_code=400, detail="Invalid email domain") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid email domain" + ) token = create_access_token({"sub": request_data.email}) return {"email": request_data.email, "access_token": token, "token_type": "bearer"} @@ -91,10 +91,10 @@ async def get_markdown( url: str, f: FilterType = FilterType.FIT, q: Optional[str] = None, - c: Optional[str] = "0", - token_data: Optional[Dict] = Depends(token_dependency) -): - result = await handle_markdown_request(url, f, q, c, config) + c: str = "0", + token_data: Optional[Dict] = Depends(token_dependency), +) -> PlainTextResponse: + result = await handle_markdown_request(url, f, config, q, c) return PlainTextResponse(result) @app.get("/llm/{url:path}", description="URL should be without http/https prefix") @@ -102,8 +102,9 @@ async def llm_endpoint( request: Request, url: str = Path(...), q: Optional[str] = Query(None), - token_data: Optional[Dict] = Depends(token_dependency) -): + # TODO: Add schema and cache support as per get_markdown + token_data: Optional[Dict] = Depends(token_dependency), +) -> JSONResponse: if not q: raise HTTPException(status_code=400, detail="Query parameter 'q' is required") if not url.startswith(('http://', 'https://')): @@ -115,33 +116,39 @@ async def llm_endpoint( raise HTTPException(status_code=500, detail=str(e)) @app.get("/schema") -async def get_schema(): +async def get_schema() -> dict[str, Serialisable]: from crawl4ai import BrowserConfig, CrawlerRunConfig return {"browser": BrowserConfig().dump(), "crawler": CrawlerRunConfig().dump()} @app.get(config["observability"]["health_check"]["endpoint"]) -async def health(): +async def health() -> dict[str, Any]: return {"status": "ok", "timestamp": time.time(), "version": __version__} @app.get(config["observability"]["prometheus"]["endpoint"]) -async def metrics(): +async def metrics() -> RedirectResponse: return RedirectResponse(url=config["observability"]["prometheus"]["endpoint"]) @app.post("/crawl") @limiter.limit(config["rate_limiting"]["default_limit"]) async def crawl( request: Request, - crawl_request: CrawlRequest, - token_data: Optional[Dict] = Depends(token_dependency) -): + crawl_request: Annotated[CrawlRequest, Depends(CrawlRequest)], + token_data: Optional[Dict] = Depends(token_dependency), +) -> JSONResponse: if not crawl_request.urls: - raise HTTPException(status_code=400, detail="At least one URL required") - + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="At least one URL 
required" + ) + + if crawl_request.crawler_config.stream: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Streaming mode not allowed for this endpoint. Use /crawl/stream instead.", + ) + results = await handle_crawl_request( - urls=crawl_request.urls, - browser_config=crawl_request.browser_config, - crawler_config=crawl_request.crawler_config, - config=config + crawl_request=crawl_request, + config=config, ) return JSONResponse(results) @@ -151,17 +158,23 @@ async def crawl( @limiter.limit(config["rate_limiting"]["default_limit"]) async def crawl_stream( request: Request, - crawl_request: CrawlRequest, - token_data: Optional[Dict] = Depends(token_dependency) + crawl_request: Annotated[CrawlRequest, Depends(CrawlRequest)], + token_data: Optional[Dict] = Depends(token_dependency), ): if not crawl_request.urls: - raise HTTPException(status_code=400, detail="At least one URL required") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="At least one URL required" + ) + + if not crawl_request.crawler_config.stream: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Streaming mode must be set for this endpoint. Use /crawl instead {crawl_request.crawler_config}.", + ) crawler, results_gen = await handle_stream_crawl_request( - urls=crawl_request.urls, - browser_config=crawl_request.browser_config, - crawler_config=crawl_request.crawler_config, - config=config + crawl_request=crawl_request, + config=config, ) return StreamingResponse( diff --git a/docker-compose.yml b/docker-compose.yml index 6a7bf7cbc..d2f22aacc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -38,7 +38,7 @@ services: platforms: - linux/amd64 profiles: ["local-amd64"] - <<: *base-config # extends yerine doğrudan yapılandırmayı dahil ettik + <<: *base-config # extends we included the configuration directly instead crawl4ai-arm64: build: diff --git a/docs/examples/async_webcrawler_multiple_urls_example.py b/docs/examples/async_webcrawler_multiple_urls_example.py index 52309d13e..93b07bf55 100644 --- a/docs/examples/async_webcrawler_multiple_urls_example.py +++ b/docs/examples/async_webcrawler_multiple_urls_example.py @@ -8,7 +8,7 @@ sys.path.append(parent_dir) import asyncio -from crawl4ai import AsyncWebCrawler +from crawl4ai import AsyncWebCrawler, CacheMode async def main(): @@ -30,7 +30,7 @@ async def main(): results = await crawler.arun_many( urls=urls, word_count_threshold=word_count_threshold, - bypass_cache=True, + cache_mode=CacheMode.BYPASS, verbose=True, ) diff --git a/docs/examples/crawlai_vs_firecrawl.py b/docs/examples/crawlai_vs_firecrawl.py index f8b70dc70..acb5cbdc0 100644 --- a/docs/examples/crawlai_vs_firecrawl.py +++ b/docs/examples/crawlai_vs_firecrawl.py @@ -6,7 +6,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from firecrawl import FirecrawlApp -from crawl4ai import AsyncWebCrawler +from crawl4ai import AsyncWebCrawler, CacheMode __data__ = os.path.join(os.path.dirname(__file__), "..", "..") + "/.data" @@ -34,7 +34,7 @@ async def compare(): url="https://www.nbcnews.com/business", # js_code=["const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"], word_count_threshold=0, - bypass_cache=True, + cache_mode=CacheMode.BYPASS, verbose=False, ) end = time.time() @@ -53,7 +53,7 @@ async def compare(): "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => 
button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" ], word_count_threshold=0, - bypass_cache=True, + cache_mode=CacheMode.BYPASS, verbose=False, ) end = time.time() diff --git a/docs/examples/quickstart.ipynb b/docs/examples/quickstart.ipynb index 56365cde0..561e563f4 100644 --- a/docs/examples/quickstart.ipynb +++ b/docs/examples/quickstart.ipynb @@ -55,7 +55,7 @@ "# %%capture\n", "!pip install crawl4ai\n", "!pip install nest_asyncio\n", - "!playwright install " + "!playwright install" ] }, { @@ -65,7 +65,6 @@ "metadata": {}, "outputs": [], "source": [ - "import asyncio\n", "import nest_asyncio\n", "nest_asyncio.apply()" ] @@ -106,13 +105,13 @@ ], "source": [ "import asyncio\n", - "from crawl4ai import AsyncWebCrawler\n", + "from crawl4ai import AsyncWebCrawler, CacheMode\n", "\n", "async def simple_crawl():\n", " async with AsyncWebCrawler() as crawler:\n", " result = await crawler.arun(\n", " url=\"https://www.nbcnews.com/business\",\n", - " bypass_cache=True # By default this is False, meaning the cache will be used\n", + " cache_mode=CacheMode.BYPASS # By default this is False, meaning the cache will be used\n", " )\n", " print(result.markdown.raw_markdown[:500]) # Print the first 500 characters\n", " \n", @@ -175,7 +174,7 @@ " url=\"https://www.nbcnews.com/business\",\n", " js_code=js_code,\n", " # wait_for=wait_for,\n", - " bypass_cache=True,\n", + " cache_mode=CacheMode.BYPASS,\n", " )\n", " print(result.markdown.raw_markdown[:500]) # Print first 500 characters\n", "\n", @@ -204,7 +203,7 @@ " excluded_tags=['nav', 'footer', 'aside'],\n", " remove_overlay_elements=True,\n", " word_count_threshold=10,\n", - " bypass_cache=True\n", + " cache_mode=CacheMode.BYPASS\n", " )\n", " full_markdown_length = len(result.markdown.raw_markdown)\n", " fit_markdown_length = len(result.markdown.fit_markdown)\n", @@ -264,7 +263,7 @@ " async with AsyncWebCrawler() as crawler:\n", " result = await crawler.arun(\n", " url=\"https://www.nbcnews.com/business\",\n", - " bypass_cache=True,\n", + " cache_mode=CacheMode.BYPASS,\n", " exclude_external_links=True,\n", " exclude_social_media_links=True,\n", " # exclude_domains=[\"facebook.com\", \"twitter.com\"]\n", @@ -314,7 +313,7 @@ " async with AsyncWebCrawler() as crawler:\n", " result = await crawler.arun(\n", " url=\"https://www.nbcnews.com/business\", \n", - " bypass_cache=True,\n", + " cache_mode=CacheMode.BYPASS,\n", " exclude_external_images=False,\n", " screenshot=True\n", " )\n", @@ -385,7 +384,7 @@ " # Perform the crawl operation\n", " result = await crawler.arun(\n", " url=\"https://crawl4ai.com\",\n", - " bypass_cache=True\n", + " cache_mode=CacheMode.BYPASS\n", " )\n", " print(result.markdown.raw_markdown[:500]) # Display the first 500 characters\n", "\n", @@ -462,7 +461,7 @@ " session_id=session_id,\n", " js_code=\"document.querySelector('.next-page-button').click();\" if page_number > 1 else None,\n", " css_selector=\".content-section\",\n", - " bypass_cache=True\n", + " cache_mode=CacheMode.BYPASS\n", " )\n", " print(f\"Page {page_number} Content:\")\n", " print(result.markdown.raw_markdown[:500]) # Print first 500 characters\n", @@ -552,7 +551,7 @@ " \"{model_name: 'GPT-4', input_fee: 'US$10.00 / 1M tokens', output_fee: 'US$30.00 / 1M tokens'}.\"\"\",\n", " **extra_args\n", " ),\n", - " bypass_cache=True,\n", + " cache_mode=CacheMode.BYPASS,\n", " )\n", " print(json.loads(result.extracted_content)[:5])\n", "\n", diff --git a/docs/examples/summarize_page.py b/docs/examples/summarize_page.py index 
da2bcd219..3eb70255d 100644 --- a/docs/examples/summarize_page.py +++ b/docs/examples/summarize_page.py @@ -1,9 +1,8 @@ import os import json -from crawl4ai.web_crawler import WebCrawler -from crawl4ai.chunking_strategy import * -from crawl4ai.extraction_strategy import * -from crawl4ai.crawler_strategy import * +from crawl4ai import CacheMode +from crawl4ai.extraction_strategy import LLMExtractionStrategy +from crawl4ai.legacy.web_crawler import WebCrawler url = r"https://marketplace.visualstudio.com/items?itemName=Unclecode.groqopilot" @@ -37,7 +36,7 @@ class PageSummary(BaseModel): "The extracted JSON format should look like this: " '{ "title": "Page Title", "summary": "Detailed summary of the page.", "brief_summary": "Brief summary in a paragraph.", "keywords": ["keyword1", "keyword2", "keyword3"] }', ), - bypass_cache=True, + cache_mode=CacheMode.BYPASS, ) page_summary = json.loads(result.extracted_content) diff --git a/docs/md_v2/core/local-files.md b/docs/md_v2/core/local-files.md index ddf27f8c8..a8530956c 100644 --- a/docs/md_v2/core/local-files.md +++ b/docs/md_v2/core/local-files.md @@ -8,11 +8,11 @@ To crawl a live web page, provide the URL starting with `http://` or `https://`, ```python import asyncio -from crawl4ai import AsyncWebCrawler +from crawl4ai import AsyncWebCrawler, CacheMode from crawl4ai.async_configs import CrawlerRunConfig async def crawl_web(): - config = CrawlerRunConfig(bypass_cache=True) + config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS) async with AsyncWebCrawler() as crawler: result = await crawler.arun( url="https://en.wikipedia.org/wiki/apple", @@ -39,8 +39,8 @@ from crawl4ai.async_configs import CrawlerRunConfig async def crawl_local_file(): local_file_path = "/path/to/apple.html" # Replace with your file path file_url = f"file://{local_file_path}" - config = CrawlerRunConfig(bypass_cache=True) - + config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS) + async with AsyncWebCrawler() as crawler: result = await crawler.arun(url=file_url, config=config) if result.success: @@ -64,8 +64,8 @@ from crawl4ai.async_configs import CrawlerRunConfig async def crawl_raw_html(): raw_html = "
<html><body><h1>Hello, World!</h1></body></html>
" raw_html_url = f"raw:{raw_html}" - config = CrawlerRunConfig(bypass_cache=True) - + config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS) + async with AsyncWebCrawler() as crawler: result = await crawler.arun(url=raw_html_url, config=config) if result.success: @@ -104,7 +104,7 @@ async def main(): async with AsyncWebCrawler() as crawler: # Step 1: Crawl the Web URL print("\n=== Step 1: Crawling the Wikipedia URL ===") - web_config = CrawlerRunConfig(bypass_cache=True) + web_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS) result = await crawler.arun(url=wikipedia_url, config=web_config) if not result.success: @@ -119,7 +119,7 @@ async def main(): # Step 2: Crawl from the Local HTML File print("=== Step 2: Crawling from the Local HTML File ===") file_url = f"file://{html_file_path.resolve()}" - file_config = CrawlerRunConfig(bypass_cache=True) + file_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS) local_result = await crawler.arun(url=file_url, config=file_config) if not local_result.success: @@ -135,7 +135,7 @@ async def main(): with open(html_file_path, 'r', encoding='utf-8') as f: raw_html_content = f.read() raw_html_url = f"raw:{raw_html_content}" - raw_config = CrawlerRunConfig(bypass_cache=True) + raw_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS) raw_result = await crawler.arun(url=raw_html_url, config=raw_config) if not raw_result.success: diff --git a/docs/releases_review/Crawl4AI_v0.3.72_Release_Announcement.ipynb b/docs/releases_review/Crawl4AI_v0.3.72_Release_Announcement.ipynb index 053bc6c59..641a8ce28 100644 --- a/docs/releases_review/Crawl4AI_v0.3.72_Release_Announcement.ipynb +++ b/docs/releases_review/Crawl4AI_v0.3.72_Release_Announcement.ipynb @@ -149,10 +149,10 @@ "metadata": {}, "outputs": [], "source": [ + "from crawl4ai import CacheMode\n", "from crawl4ai.extraction_strategy import LLMExtractionStrategy\n", "from pydantic import BaseModel\n", "import json, os\n", - "from typing import List\n", "\n", "# Define classes for the knowledge graph structure\n", "class Landmark(BaseModel):\n", @@ -187,7 +187,7 @@ " result = await crawler.arun(\n", " url=\"https://janineintheworld.com/places-to-visit-in-central-mexico\",\n", " extraction_strategy=strategy,\n", - " bypass_cache=True,\n", + " cache_mode=CacheMode.BYPASS,\n", " magic=True\n", " )\n", " \n", diff --git a/pyproject.toml b/pyproject.toml index ad07548d6..7b6aa4b84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=64.0.0", "wheel"] +requires = ["setuptools>=78.0.1", "wheel"] build-backend = "setuptools.build_meta" [project] @@ -8,7 +8,7 @@ dynamic = ["version"] description = "🚀🤖 Crawl4AI: Open-source LLM Friendly Web Crawler & scraper" readme = "README.md" requires-python = ">=3.9" -license = {text = "MIT"} +license = "MIT" authors = [ {name = "Unclecode", email = "unclecode@kidocode.com"} ] @@ -43,11 +43,13 @@ dependencies = [ "faust-cchardet>=2.1.19", "aiohttp>=3.11.11", "humanize>=4.10.0", + "pdf2image>=1.17.0", + "bitarray>=3.2.0", + "mmh3>=5.1.0", ] classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", @@ -57,20 +59,17 @@ classifiers = [ ] [project.optional-dependencies] -pdf = ["PyPDF2"] +pdf = ["PyPDF2", "pdf2image"] torch = ["torch", "nltk", "scikit-learn"] transformer = ["transformers", "tokenizers"] cosine = ["torch", 
"transformers", "nltk"] sync = ["selenium"] all = [ - "PyPDF2", - "torch", - "nltk", - "scikit-learn", - "transformers", - "tokenizers", - "selenium", - "PyPDF2" + "crawl4ai[cosine]", + "crawl4ai[pdf]", + "crawl4ai[sync]", + "crawl4ai[torch]", + "crawl4ai[transformer]", ] [project.scripts] @@ -96,3 +95,44 @@ crawl4ai = { workspace = true } dev = [ "crawl4ai", ] +docker = [ + "fastapi>=0.115.11", + "redis>=5.2.1", +] +test = [ + "matplotlib>=3.9.4", + "nest-asyncio>=1.6.0", + "nltk>=3.9.1", + "pytest-aiohttp>=1.1.0", + "pytest-asyncio>=0.25.3", + "pytest-cov>=6.0.0", + "pytest-httpserver>=1.1.2", + "pytest-timeout>=2.3.1", + "pytest>=8.3.5", + "scipy>=1.13.1", + "selenium>=4.29.0", + "tabulate>=0.9.0", + "torch>=2.6.0", + "transformers>=4.49.0", +] + +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" # Prevent deprecation warning +# Disable deprecation warnings for code we don't control. +filterwarnings = [ + "ignore::DeprecationWarning", + "default::DeprecationWarning:__main__", + "default::DeprecationWarning:crawl4ai.*", +] +timeout = 20 +timeout_func_only = true + +# Basic configuration for the `ruff` tool to preserve basic formatting. +[tool.ruff] +line-length = 120 +target-version = "py39" + +# As a temporary workaround, exclude all files from ruff formatting as the +# code base has various code style and would introduce a lot of noise. +[tool.ruff.format] +exclude = ["*.py"] diff --git a/setup.py b/setup.py index 16b1b53cb..ef584ff7a 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,6 @@ classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", diff --git a/tests/20241401/test_advanced_deep_crawl.py b/tests/20241401/test_advanced_deep_crawl.py index dd291f675..c0595eb6e 100644 --- a/tests/20241401/test_advanced_deep_crawl.py +++ b/tests/20241401/test_advanced_deep_crawl.py @@ -1,16 +1,20 @@ -import asyncio +import sys import time +from httpx import codes +import pytest from crawl4ai import CrawlerRunConfig, AsyncWebCrawler, CacheMode from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy -from crawl4ai.deep_crawling import BFSDeepCrawlStrategy, BestFirstCrawlingStrategy +from crawl4ai.deep_crawling import BestFirstCrawlingStrategy from crawl4ai.deep_crawling.filters import FilterChain, URLPatternFilter, DomainFilter, ContentTypeFilter, ContentRelevanceFilter from crawl4ai.deep_crawling.scorers import KeywordRelevanceScorer -# from crawl4ai.deep_crawling import BFSDeepCrawlStrategy, BestFirstCrawlingStrategy +from crawl4ai.types import CrawlResult -async def main(): +@pytest.mark.asyncio +@pytest.mark.timeout(60) +async def test_deep_crawl(): """Example deep crawl of documentation site.""" filter_chain = FilterChain([ URLPatternFilter(patterns=["*2025*"]), @@ -18,12 +22,14 @@ async def main(): ContentRelevanceFilter(query="Use of artificial intelligence in Defence applications", threshold=1), ContentTypeFilter(allowed_types=["text/html","application/javascript"]) ]) + max_pages: int = 5 config = CrawlerRunConfig( deep_crawl_strategy = BestFirstCrawlingStrategy( max_depth=2, include_external=False, filter_chain=filter_chain, url_scorer=KeywordRelevanceScorer(keywords=["anduril", "defence", "AI"]), + max_pages=max_pages, ), stream=False, verbose=True, @@ -35,12 +41,24 @@ async def main(): print("Starting deep crawl in streaming mode:") 
config.stream = True start_time = time.perf_counter() + result: CrawlResult + pages: int = 0 async for result in await crawler.arun( url="https://techcrunch.com", config=config ): + assert result.status_code == codes.OK + assert result.url + assert result.metadata + assert result.metadata.get("depth", -1) >= 0 + assert result.metadata.get("depth", -1) <= 2 + pages += 1 print(f"→ {result.url} (Depth: {result.metadata.get('depth', 0)})") - print(f"Duration: {time.perf_counter() - start_time:.2f} seconds") + + print(f"Crawled {pages} pages in: {time.perf_counter() - start_time:.2f} seconds") + assert pages == max_pages if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_acyn_crawl_wuth_http_crawler_strategy.py b/tests/20241401/test_async_crawl_with_http_crawler_strategy.py similarity index 52% rename from tests/20241401/test_acyn_crawl_wuth_http_crawler_strategy.py rename to tests/20241401/test_async_crawl_with_http_crawler_strategy.py index 2727d1e4c..9b6b815e6 100644 --- a/tests/20241401/test_acyn_crawl_wuth_http_crawler_strategy.py +++ b/tests/20241401/test_async_crawl_with_http_crawler_strategy.py @@ -1,16 +1,30 @@ -import asyncio +import sys + +from httpx import codes +import pytest + from crawl4ai import ( AsyncWebCrawler, - CrawlerRunConfig, - HTTPCrawlerConfig, CacheMode, + CrawlerRunConfig, DefaultMarkdownGenerator, - PruningContentFilter + HTTPCrawlerConfig, + PruningContentFilter, ) from crawl4ai.async_crawler_strategy import AsyncHTTPCrawlerStrategy from crawl4ai.async_logger import AsyncLogger -async def main(): + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url", + [ + "https://example.com", + "https://httpbin.org/get", + "raw://Test content" + ] +) +async def test_async_crawl(url: str): # Initialize HTTP crawler strategy http_strategy = AsyncHTTPCrawlerStrategy( browser_config=HTTPCrawlerConfig( @@ -27,30 +41,23 @@ async def main(): cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator( content_filter=PruningContentFilter( - threshold=0.48, - threshold_type="fixed", + threshold=0.48, + threshold_type="fixed", min_word_threshold=0 ) ) ) - - # Test different URLs - urls = [ - "https://example.com", - "https://httpbin.org/get", - "raw://Test content" - ] - - for url in urls: - print(f"\n=== Testing {url} ===") - try: - result = await crawler.arun(url=url, config=crawler_config) - print(f"Status: {result.status_code}") - print(f"Raw HTML length: {len(result.html)}") - if hasattr(result, 'markdown'): - print(f"Markdown length: {len(result.markdown.raw_markdown)}") - except Exception as e: - print(f"Error: {e}") + + result = await crawler.arun(url=url, config=crawler_config) + assert result.status_code == codes.OK + assert result.html + assert result.markdown + assert result.markdown.raw_markdown + print(f"Status: {result.status_code}") + print(f"Raw HTML length: {len(result.html)}") + print(f"Markdown length: {len(result.markdown.raw_markdown)}") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_async_crawler_strategy.py b/tests/20241401/test_async_crawler_strategy.py index 68fe4a886..ffcfe40fd 100644 --- a/tests/20241401/test_async_crawler_strategy.py +++ b/tests/20241401/test_async_crawler_strategy.py @@ -1,13 +1,14 @@ +import os +import sys +from pathlib import 
Path +from typing import AsyncGenerator, Optional + +from httpx import codes import pytest import pytest_asyncio -import asyncio -from typing import Dict, Any -from pathlib import Path -from unittest.mock import MagicMock, patch -import os + from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig from crawl4ai.async_crawler_strategy import AsyncPlaywrightCrawlerStrategy -from crawl4ai.models import AsyncCrawlResponse from crawl4ai.async_logger import AsyncLogger, LogLevel CRAWL4AI_HOME_DIR = Path(os.path.expanduser("~")).joinpath(".crawl4ai") @@ -17,7 +18,7 @@ # Test Config Files @pytest.fixture -def basic_browser_config(): +def basic_browser_config() -> BrowserConfig: return BrowserConfig( browser_type="chromium", headless=True, @@ -25,20 +26,19 @@ def basic_browser_config(): ) @pytest.fixture -def advanced_browser_config(): +def advanced_browser_config() -> BrowserConfig: return BrowserConfig( - browser_type="chromium", + browser_type="chromium", headless=True, use_managed_browser=True, - user_data_dir=CRAWL4AI_HOME_DIR.joinpath("profiles", "test_profile"), - # proxy="http://localhost:8080", + user_data_dir=CRAWL4AI_HOME_DIR.joinpath("profiles", "test_profile").as_posix(), viewport_width=1920, viewport_height=1080, user_agent_mode="random" ) @pytest.fixture -def basic_crawler_config(): +def basic_crawler_config() -> CrawlerRunConfig: return CrawlerRunConfig( word_count_threshold=100, wait_until="domcontentloaded", @@ -46,12 +46,12 @@ def basic_crawler_config(): ) @pytest.fixture -def logger(): +def logger() -> AsyncLogger: return AsyncLogger(verbose=True, log_level=LogLevel.DEBUG) @pytest_asyncio.fixture -async def crawler_strategy(basic_browser_config, logger): - strategy = AsyncPlaywrightCrawlerStrategy(browser_config=basic_browser_config, logger=logger) +async def crawler_strategy(basic_browser_config, logger) -> AsyncGenerator[AsyncPlaywrightCrawlerStrategy, None]: + strategy: AsyncPlaywrightCrawlerStrategy = AsyncPlaywrightCrawlerStrategy(browser_config=basic_browser_config, logger=logger) await strategy.start() yield strategy await strategy.close() @@ -67,7 +67,7 @@ async def test_browser_config_initialization(): assert config.user_agent is not None assert config.headless is True -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_persistent_browser_config(): config = BrowserConfig( use_persistent_context=True, @@ -78,17 +78,17 @@ async def test_persistent_browser_config(): # Crawler Strategy Tests @pytest.mark.asyncio -async def test_basic_page_load(crawler_strategy): +async def test_basic_page_load(crawler_strategy: AsyncPlaywrightCrawlerStrategy): response = await crawler_strategy.crawl( "https://example.com", CrawlerRunConfig() ) - assert response.status_code == 200 + assert response.status_code == codes.OK assert len(response.html) > 0 assert "Example Domain" in response.html @pytest.mark.asyncio -async def test_screenshot_capture(crawler_strategy): +async def test_screenshot_capture(crawler_strategy: AsyncPlaywrightCrawlerStrategy): config = CrawlerRunConfig(screenshot=True) response = await crawler_strategy.crawl( "https://example.com", @@ -98,17 +98,17 @@ async def test_screenshot_capture(crawler_strategy): assert len(response.screenshot) > 0 @pytest.mark.asyncio -async def test_pdf_generation(crawler_strategy): +async def test_pdf_generation(crawler_strategy: AsyncPlaywrightCrawlerStrategy): config = CrawlerRunConfig(pdf=True) response = await crawler_strategy.crawl( - "https://example.com", + "https://example.com", config ) assert response.pdf_data is 
not None assert len(response.pdf_data) > 0 @pytest.mark.asyncio -async def test_handle_js_execution(crawler_strategy): +async def test_handle_js_execution(crawler_strategy: AsyncPlaywrightCrawlerStrategy): config = CrawlerRunConfig( js_code="document.body.style.backgroundColor = 'red';" ) @@ -116,11 +116,11 @@ async def test_handle_js_execution(crawler_strategy): "https://example.com", config ) - assert response.status_code == 200 + assert response.status_code == codes.OK assert 'background-color: red' in response.html.lower() @pytest.mark.asyncio -async def test_multiple_js_commands(crawler_strategy): +async def test_multiple_js_commands(crawler_strategy: AsyncPlaywrightCrawlerStrategy): js_commands = [ "document.body.style.backgroundColor = 'blue';", "document.title = 'Modified Title';", @@ -131,14 +131,14 @@ async def test_multiple_js_commands(crawler_strategy): "https://example.com", config ) - assert response.status_code == 200 + assert response.status_code == codes.OK assert 'background-color: blue' in response.html.lower() assert 'id="test"' in response.html assert '>Test Content<' in response.html assert 'Modified Title' in response.html @pytest.mark.asyncio -async def test_complex_dom_manipulation(crawler_strategy): +async def test_complex_dom_manipulation(crawler_strategy: AsyncPlaywrightCrawlerStrategy): js_code = """ // Create a complex structure const container = document.createElement('div'); @@ -162,7 +162,7 @@ async def test_complex_dom_manipulation(crawler_strategy): "https://example.com", config ) - assert response.status_code == 200 + assert response.status_code == codes.OK assert 'class="test-container"' in response.html assert 'class="test-list"' in response.html assert 'class="item-1"' in response.html @@ -171,7 +171,7 @@ async def test_complex_dom_manipulation(crawler_strategy): assert '>Item 3<' in response.html @pytest.mark.asyncio -async def test_style_modifications(crawler_strategy): +async def test_style_modifications(crawler_strategy: AsyncPlaywrightCrawlerStrategy): js_code = """ const testDiv = document.createElement('div'); testDiv.id = 'style-test'; @@ -184,7 +184,7 @@ async def test_style_modifications(crawler_strategy): "https://example.com", config ) - assert response.status_code == 200 + assert response.status_code == codes.OK assert 'id="style-test"' in response.html assert 'color: green' in response.html.lower() assert 'font-size: 20px' in response.html.lower() @@ -192,7 +192,7 @@ async def test_style_modifications(crawler_strategy): assert '>Styled Content<' in response.html @pytest.mark.asyncio -async def test_dynamic_content_loading(crawler_strategy): +async def test_dynamic_content_loading(crawler_strategy: AsyncPlaywrightCrawlerStrategy): js_code = """ // Simulate dynamic content loading setTimeout(() => { @@ -213,33 +213,41 @@ async def test_dynamic_content_loading(crawler_strategy): "https://example.com", config ) - assert response.status_code == 200 + assert response.status_code == codes.OK assert 'id="loading"' in response.html assert '>Loading...Dynamically Loaded<' in response.html -# @pytest.mark.asyncio -# async def test_js_return_values(crawler_strategy): -# js_code = """ -# return { -# title: document.title, -# metaCount: document.getElementsByTagName('meta').length, -# bodyClass: document.body.className -# }; -# """ -# config = CrawlerRunConfig(js_code=js_code) -# response = await crawler_strategy.crawl( -# "https://example.com", -# config -# ) -# assert response.status_code == 200 -# assert 'Example Domain' in response.html -# assert 
'meta name="viewport"' in response.html -# assert 'class="main"' in response.html +@pytest.mark.asyncio +async def test_js_return_values(crawler_strategy: AsyncPlaywrightCrawlerStrategy): + js_code = """ + return { + title: document.title, + metaCount: document.getElementsByTagName('meta').length, + bodyClass: document.body.className + }; + """ + config = CrawlerRunConfig(js_code=js_code) + response = await crawler_strategy.crawl( + "https://example.com", + config + ) + assert response.status_code == codes.OK + assert 'Example Domain' in response.html + assert 'meta name="viewport"' in response.html + assert response.js_execution_result is not None + assert response.js_execution_result.get("success") + results: Optional[list[dict]] = response.js_execution_result.get("results") + assert results + assert results[0] == { + "title": "Example Domain", + "metaCount": 3, + "bodyClass": "" + } @pytest.mark.asyncio -async def test_async_js_execution(crawler_strategy): +async def test_async_js_execution(crawler_strategy: AsyncPlaywrightCrawlerStrategy): js_code = """ await new Promise(resolve => setTimeout(resolve, 1000)); document.body.style.color = 'green'; @@ -251,34 +259,38 @@ async def test_async_js_execution(crawler_strategy): "https://example.com", config ) - assert response.status_code == 200 + assert response.status_code == codes.OK assert 'color: green' in response.html.lower() -# @pytest.mark.asyncio -# async def test_js_error_handling(crawler_strategy): -# js_code = """ -# // Intentionally cause different types of errors -# const results = []; -# try { -# nonExistentFunction(); -# } catch (e) { -# results.push(e.name); -# } -# try { -# JSON.parse('{invalid}'); -# } catch (e) { -# results.push(e.name); -# } -# return results; -# """ -# config = CrawlerRunConfig(js_code=js_code) -# response = await crawler_strategy.crawl( -# "https://example.com", -# config -# ) -# assert response.status_code == 200 -# assert 'ReferenceError' in response.html -# assert 'SyntaxError' in response.html +@pytest.mark.asyncio +async def test_js_error_handling(crawler_strategy: AsyncPlaywrightCrawlerStrategy): + js_code = """ + // Intentionally cause different types of errors + const results = []; + try { + nonExistentFunction(); + } catch (e) { + results.push(e.name); + } + try { + JSON.parse('{invalid}'); + } catch (e) { + results.push(e.name); + } + return results; + """ + config = CrawlerRunConfig(js_code=js_code) + response = await crawler_strategy.crawl( + "https://example.com", + config + ) + assert response.status_code == codes.OK + assert response.js_execution_result is not None + assert response.js_execution_result.get("success") + results: Optional[list[dict]] = response.js_execution_result.get("results") + assert results + assert 'ReferenceError' in results[0] + assert 'SyntaxError' in results[0] @pytest.mark.asyncio async def test_handle_navigation_timeout(): @@ -288,7 +300,7 @@ async def test_handle_navigation_timeout(): await strategy.crawl("https://example.com", config) @pytest.mark.asyncio -async def test_session_management(crawler_strategy): +async def test_session_management(crawler_strategy: AsyncPlaywrightCrawlerStrategy): config = CrawlerRunConfig(session_id="test_session") response1 = await crawler_strategy.crawl( "https://example.com", @@ -298,23 +310,23 @@ async def test_session_management(crawler_strategy): "https://example.com", config ) - assert response1.status_code == 200 - assert response2.status_code == 200 + assert response1.status_code == codes.OK + assert response2.status_code 
== codes.OK @pytest.mark.asyncio -async def test_process_iframes(crawler_strategy): +async def test_process_iframes(crawler_strategy: AsyncPlaywrightCrawlerStrategy): config = CrawlerRunConfig( process_iframes=True, wait_for_images=True ) response = await crawler_strategy.crawl( "https://example.com", - config + config ) - assert response.status_code == 200 + assert response.status_code == codes.OK @pytest.mark.asyncio -async def test_stealth_mode(crawler_strategy): +async def test_stealth_mode(crawler_strategy: AsyncPlaywrightCrawlerStrategy): config = CrawlerRunConfig( simulate_user=True, override_navigator=True @@ -323,16 +335,16 @@ async def test_stealth_mode(crawler_strategy): "https://bot.sannysoft.com", config ) - assert response.status_code == 200 + assert response.status_code == codes.OK -# Error Handling Tests +# Error Handling Tests @pytest.mark.asyncio async def test_invalid_url(): with pytest.raises(ValueError): async with AsyncPlaywrightCrawlerStrategy() as strategy: await strategy.crawl("not_a_url", CrawlerRunConfig()) -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_network_error_handling(): config = CrawlerRunConfig() with pytest.raises(Exception): @@ -340,4 +352,6 @@ async def test_network_error_handling(): await strategy.crawl("https://invalid.example.com", config) if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_async_markdown_generator.py b/tests/20241401/test_async_markdown_generator.py index 145b98b58..2bab64267 100644 --- a/tests/20241401/test_async_markdown_generator.py +++ b/tests/20241401/test_async_markdown_generator.py @@ -1,10 +1,11 @@ -import asyncio -from typing import Dict +import sys + +from _pytest.mark.structures import ParameterSet # pyright: ignore[reportPrivateImportUsage] from crawl4ai.content_filter_strategy import BM25ContentFilter, PruningContentFilter from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator import time +import pytest -# Test HTML samples TEST_HTML_SAMPLES = { "basic": """ @@ -16,7 +17,7 @@ """, - + "complex": """ @@ -27,7 +28,7 @@

Important content paragraph with useful link.
Key Section
- Detailed explanation with multiple sentences. This should be kept
+ Detailed explanation with multiple sentences. This should be kept
in the final output. Very important information here.
@@ -36,7 +37,7 @@ """, - + "edge_cases": """
@@ -50,10 +51,10 @@
""", - + "links_citations": """ -

Document with Links
+ Article with Links
First link to Example 1
Second link to Test 2
Image link: test image
@@ -62,110 +63,87 @@ """, } -def test_content_filters() -> Dict[str, Dict[str, int]]: +GENERATORS = { + "no_filter": DefaultMarkdownGenerator(), + "pruning": DefaultMarkdownGenerator( + content_filter=PruningContentFilter(threshold=0.48) + ), + "bm25": DefaultMarkdownGenerator( + content_filter=BM25ContentFilter( + user_query="test article content important" + ) + ) +} + + +def filter_params() -> list[ParameterSet]: + """Return a list of test parameters for the content filter tests.""" + return [ + pytest.param(html, id=name) for name, html in TEST_HTML_SAMPLES.items() + ] + +@pytest.mark.parametrize("html", filter_params()) +def test_content_filters(html: str): """Test various content filtering strategies and return length comparisons.""" - results = {} - # Initialize filters pruning_filter = PruningContentFilter( threshold=0.48, threshold_type="fixed", min_word_threshold=2 ) - + bm25_filter = BM25ContentFilter( bm25_threshold=1.0, user_query="test article content important" ) - - # Test each HTML sample - for test_name, html in TEST_HTML_SAMPLES.items(): - # Store results for this test case - results[test_name] = {} - - # Test PruningContentFilter - start_time = time.time() - pruned_content = pruning_filter.filter_content(html) - pruning_time = time.time() - start_time - - # Test BM25ContentFilter - start_time = time.time() - bm25_content = bm25_filter.filter_content(html) - bm25_time = time.time() - start_time - - # Store results - results[test_name] = { - "original_length": len(html), - "pruned_length": sum(len(c) for c in pruned_content), - "bm25_length": sum(len(c) for c in bm25_content), - "pruning_time": pruning_time, - "bm25_time": bm25_time - } - - return results - -def test_markdown_generation(): + + # Test PruningContentFilter + start_time = time.time() + pruned_content = pruning_filter.filter_content(html) + pruning_time = time.time() - start_time + + # Test BM25ContentFilter + start_time = time.time() + bm25_content = bm25_filter.filter_content(html) + bm25_time = time.time() - start_time + + assert len(pruned_content) > 0 + assert len(bm25_content) > 0 + print(f"Original length: {len(html)}") + print(f"Pruned length: {sum(len(c) for c in pruned_content)} ({pruning_time:.3f}s)") + print(f"BM25 length: {sum(len(c) for c in bm25_content)} ({bm25_time:.3f}s)") + + +def markdown_params() -> list[ParameterSet]: + """Return a list of test parameters for the content filter tests.""" + params: list[ParameterSet] = [] + for name, html in TEST_HTML_SAMPLES.items(): + for gen_name, generator in GENERATORS.items(): + params.append(pytest.param(html, generator, id=f"{name}_{gen_name}")) + return params + +@pytest.mark.parametrize("html,generator", markdown_params()) +def test_markdown_generation(html: str, generator: DefaultMarkdownGenerator): """Test markdown generation with different configurations.""" - results = [] - - # Initialize generators with different configurations - generators = { - "no_filter": DefaultMarkdownGenerator(), - "pruning": DefaultMarkdownGenerator( - content_filter=PruningContentFilter(threshold=0.48) - ), - "bm25": DefaultMarkdownGenerator( - content_filter=BM25ContentFilter( - user_query="test article content important" - ) - ) - } - - # Test each generator with each HTML sample - for test_name, html in TEST_HTML_SAMPLES.items(): - for gen_name, generator in generators.items(): - start_time = time.time() - result = generator.generate_markdown( - html, - base_url="http://example.com", - citations=True - ) - - results.append({ - "test_case": test_name, - 
"generator": gen_name, - "time": time.time() - start_time, - "raw_length": len(result.raw_markdown), - "fit_length": len(result.fit_markdown) if result.fit_markdown else 0, - "citations": len(result.references_markdown) - }) - - return results - -def main(): - """Run all tests and print results.""" - print("Starting content filter tests...") - filter_results = test_content_filters() - - print("\nContent Filter Results:") - print("-" * 50) - for test_name, metrics in filter_results.items(): - print(f"\nTest case: {test_name}") - print(f"Original length: {metrics['original_length']}") - print(f"Pruned length: {metrics['pruned_length']} ({metrics['pruning_time']:.3f}s)") - print(f"BM25 length: {metrics['bm25_length']} ({metrics['bm25_time']:.3f}s)") - - print("\nStarting markdown generation tests...") - markdown_results = test_markdown_generation() - - print("\nMarkdown Generation Results:") - print("-" * 50) - for result in markdown_results: - print(f"\nTest: {result['test_case']} - Generator: {result['generator']}") - print(f"Time: {result['time']:.3f}s") - print(f"Raw length: {result['raw_length']}") - print(f"Fit length: {result['fit_length']}") - print(f"Citations: {result['citations']}") + + start_time = time.time() + result = generator.generate_markdown( + html, + base_url="http://example.com", + citations=True + ) + + assert result is not None + assert result.raw_markdown is not None + assert result.fit_markdown is not None + assert result.references_markdown is not None + + print(f"Time: {time.time() - start_time:.3f}s") + print(f"Raw length: {len(result.raw_markdown)}") + print(f"Fit length: {len(result.fit_markdown) if result.fit_markdown else 0}") + print(f"Citations: {len(result.references_markdown)}") if __name__ == "__main__": - main() \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_async_webcrawler.py b/tests/20241401/test_async_webcrawler.py index 4d7aa815f..f9a2dd229 100644 --- a/tests/20241401/test_async_webcrawler.py +++ b/tests/20241401/test_async_webcrawler.py @@ -1,9 +1,10 @@ -import asyncio +import sys + import pytest -from typing import List + from crawl4ai import ( AsyncWebCrawler, - BrowserConfig, + BrowserConfig, CrawlerRunConfig, MemoryAdaptiveDispatcher, RateLimiter, @@ -30,7 +31,6 @@ async def test_viewport_config(viewport): result = await crawler.arun( url="https://example.com", config=CrawlerRunConfig( - # cache_mode=CacheMode.BYPASS, page_timeout=30000 # 30 seconds ) ) @@ -47,7 +47,7 @@ async def test_memory_management(): ) dispatcher = MemoryAdaptiveDispatcher( - memory_threshold_percent=70.0, + memory_threshold_percent=80.0, check_interval=1.0, max_session_permit=5 ) @@ -76,7 +76,7 @@ async def test_rate_limiting(): max_delay=5.0, max_retries=2 ), - memory_threshold_percent=70.0 + memory_threshold_percent=80.0, ) urls = [ @@ -143,7 +143,6 @@ async def test_error_handling(error_url): assert result.error_message is not None if __name__ == "__main__": - asyncio.run(test_viewport_config((1024, 768))) - asyncio.run(test_memory_management()) - asyncio.run(test_rate_limiting()) - asyncio.run(test_javascript_execution()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_cache_context.py b/tests/20241401/test_cache_context.py index 0f42f9fdd..9bead1276 100644 --- a/tests/20241401/test_cache_context.py +++ b/tests/20241401/test_cache_context.py @@ -1,7 +1,13 @@ 
-import asyncio -from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode -from playwright.async_api import Page, BrowserContext +import sys +import pytest +from playwright.async_api import BrowserContext, Page + +from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode, CrawlerRunConfig +from crawl4ai.models import CrawlResultContainer + + +@pytest.mark.asyncio async def test_reuse_context_by_config(): # We will store each context ID in these maps to confirm reuse context_ids_for_A = [] @@ -12,6 +18,7 @@ async def on_page_context_created(page: Page, context: BrowserContext, config: C c_id = id(context) print(f"[HOOK] on_page_context_created - Context ID: {c_id}") # Distinguish which config we used by checking a custom hook param + assert config.shared_data is not None config_label = config.shared_data.get("config_label", "unknown") if config_label == "A": context_ids_for_A.append(c_id) @@ -55,11 +62,13 @@ async def on_page_context_created(page: Page, context: BrowserContext, config: C print("\n--- Crawling with config A (text_mode=True) ---") for _ in range(2): # Pass an extra kwarg to the hook so we know which config is being used - await crawler.arun(test_url, config=configA) + result: CrawlResultContainer = await crawler.arun(test_url, config=configA) + assert result.success print("\n--- Crawling with config B (text_mode=False) ---") for _ in range(2): - await crawler.arun(test_url, config=configB) + result = await crawler.arun(test_url, config=configB) + assert result.success # Close the crawler (shuts down the browser, closes contexts) await crawler.close() @@ -68,18 +77,11 @@ async def on_page_context_created(page: Page, context: BrowserContext, config: C print("\n=== RESULTS ===") print(f"Config A context IDs: {context_ids_for_A}") print(f"Config B context IDs: {context_ids_for_B}") - if len(set(context_ids_for_A)) == 1: - print("✅ All config A crawls used the SAME BrowserContext.") - else: - print("❌ Config A crawls created multiple contexts unexpectedly.") - if len(set(context_ids_for_B)) == 1: - print("✅ All config B crawls used the SAME BrowserContext.") - else: - print("❌ Config B crawls created multiple contexts unexpectedly.") - if set(context_ids_for_A).isdisjoint(context_ids_for_B): - print("✅ Config A context is different from Config B context.") - else: - print("❌ A and B ended up sharing the same context somehow!") + assert len(set(context_ids_for_A)) == 1 + assert len(set(context_ids_for_B)) == 1 + assert set(context_ids_for_A).isdisjoint(context_ids_for_B) if __name__ == "__main__": - asyncio.run(test_reuse_context_by_config()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_crawlers.py b/tests/20241401/test_crawlers.py deleted file mode 100644 index 45fb8fcb3..000000000 --- a/tests/20241401/test_crawlers.py +++ /dev/null @@ -1,17 +0,0 @@ - -# example_usageexample_usageexample_usage# example_usage.py -import asyncio -from crawl4ai.crawlers import get_crawler - -async def main(): - # Get the registered crawler - example_crawler = get_crawler("example_site.content") - - # Crawl example.com - result = await example_crawler(url="https://example.com") - - print(result) - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tests/20241401/test_deep_crawl.py b/tests/20241401/test_deep_crawl.py index 2f533cc55..28cc15775 100644 --- a/tests/20241401/test_deep_crawl.py +++ b/tests/20241401/test_deep_crawl.py @@ -1,15 +1,51 @@ -import 
asyncio +import sys import time +from httpx import codes +import pytest from crawl4ai import CrawlerRunConfig, AsyncWebCrawler, CacheMode from crawl4ai.content_scraping_strategy import LXMLWebScrapingStrategy from crawl4ai.deep_crawling import BFSDeepCrawlStrategy -# from crawl4ai.deep_crawling import BFSDeepCrawlStrategy, BestFirstCrawlingStrategy +from pytest_httpserver import HTTPServer -async def main(): - """Example deep crawl of documentation site.""" + +URLS = [ + "/", + "/level1", + "/level2/article1", + "/level2/article2", +] + +@pytest.fixture +def site(httpserver: HTTPServer) -> HTTPServer: + """Fixture to serve multiple pages for a crawl.""" + httpserver.expect_request("/").respond_with_data(content_type="text/html", response_data=""" + + Go to level 1 + + """) + httpserver.expect_request("/level1").respond_with_data(content_type="text/html", response_data=""" + + Go to level 2 - Article 1 + Go to level 2 - Article 2 + + """) + httpserver.expect_request("/level2/article1").respond_with_data(content_type="text/html", response_data=""" + +

This is level 2 - Article 1
+ + """) + httpserver.expect_request("/level2/article2").respond_with_data(content_type="text/html", response_data=""" + +
This is level 2 - Article 2
+ + """) + return httpserver + +@pytest.mark.asyncio +async def test_deep_crawl_batch(site: HTTPServer): config = CrawlerRunConfig( deep_crawl_strategy = BFSDeepCrawlStrategy( max_depth=2, @@ -25,22 +61,52 @@ async def main(): start_time = time.perf_counter() print("\nStarting deep crawl in batch mode:") results = await crawler.arun( - url="https://docs.crawl4ai.com", + url=site.url_for("/"), config=config ) print(f"Crawled {len(results)} pages") print(f"Example page: {results[0].url}") print(f"Duration: {time.perf_counter() - start_time:.2f} seconds\n") + assert len(results) == len(URLS) + for idx, result in enumerate(results): + assert result.url == site.url_for(URLS[idx]) + assert result.status_code == codes.OK + +@pytest.mark.asyncio +async def test_deep_crawl_stream(site: HTTPServer): + config = CrawlerRunConfig( + deep_crawl_strategy = BFSDeepCrawlStrategy( + max_depth=2, + include_external=False + ), + stream=True, + verbose=True, + cache_mode=CacheMode.BYPASS, + scraping_strategy=LXMLWebScrapingStrategy() + ) + + async with AsyncWebCrawler() as crawler: print("Starting deep crawl in streaming mode:") - config.stream = True start_time = time.perf_counter() + last_time = start_time + idx = 0 async for result in await crawler.arun( - url="https://docs.crawl4ai.com", + url=site.url_for("/"), config=config ): - print(f"→ {result.url} (Depth: {result.metadata.get('depth', 0)})") + now = time.perf_counter() + duration = now - last_time + last_time = now + assert result.status_code == codes.OK + assert result.url == site.url_for(URLS[idx]) + assert result.metadata + print(f"→ {result.url} (Depth: {result.metadata.get('depth', 0)}) ({duration:.2f} seconds)") + idx += 1 + print(f"Crawled {idx} pages") print(f"Duration: {time.perf_counter() - start_time:.2f} seconds") if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_deep_crawl_filters.py b/tests/20241401/test_deep_crawl_filters.py index 948bbcbdc..f5a7a0659 100644 --- a/tests/20241401/test_deep_crawl_filters.py +++ b/tests/20241401/test_deep_crawl_filters.py @@ -1,279 +1,259 @@ -from crawl4ai.deep_crawling.filters import ContentRelevanceFilter, URLPatternFilter, DomainFilter, ContentTypeFilter, SEOFilter -async def test_pattern_filter(): - # Test cases as list of tuples instead of dict for multiple patterns - test_cases = [ - # Simple suffix patterns (*.html) - ("*.html", { +import sys + +from typing import List, Union +from _pytest.mark.structures import ParameterSet +import pytest +from re import Pattern + +from crawl4ai.deep_crawling.filters import ( + ContentRelevanceFilter, + ContentTypeFilter, + DomainFilter, + SEOFilter, + URLPatternFilter, +) + + +def pattern_filter_params() -> list[ParameterSet]: + params: list[ParameterSet] = [] + def build_params( + name: str, + pattern: Union[str, Pattern, List[Union[str, Pattern]]], + test_urls: dict[str, bool], + ): + for url, expected in test_urls.items(): + params.append(pytest.param(pattern, url, expected, id=f"{name} {url}")) + + # Simple suffix patterns (*.html) + build_params( + "simple-suffix", + "*.html", { "https://example.com/page.html": True, "https://example.com/path/doc.html": True, "https://example.com/page.htm": False, "https://example.com/page.html?param=1": True, - }), - - # Path prefix patterns (/foo/*) - ("*/article/*", { + }, + ) + + # Path prefix patterns (/foo/*) + build_params( + "path-prefix", + "*/article/*", { 
"https://example.com/article/123": True, "https://example.com/blog/article/456": True, "https://example.com/articles/789": False, "https://example.com/article": False, - }), - - # Complex patterns - ("blog-*-[0-9]", { + } + ) + + # Complex patterns + build_params( + "complex-pattern", + "blog-*-[0-9]", { "https://example.com/blog-post-1": True, "https://example.com/blog-test-9": True, "https://example.com/blog-post": False, "https://example.com/blog-post-x": False, - }), - - # Multiple patterns case - (["*.pdf", "*/download/*"], { + } + ) + + # Multiple patterns case + build_params( + "multiple-patterns", + ["*.pdf", "*/download/*"], { "https://example.com/doc.pdf": True, "https://example.com/download/file.txt": True, "https://example.com/path/download/doc": True, "https://example.com/uploads/file.txt": False, - }), - - # Edge cases - ("*", { + } + ) + + # Edge cases + build_params( + "edge-cases", + "*", { "https://example.com": True, "": True, "http://test.com/path": True, - }), - - # Complex regex - (r"^https?://.*\.example\.com/\d+", { + }, + ) + + # Complex regex + build_params( + "complex-regex", + r"^https?://.*\.example\.com/\d+", { "https://sub.example.com/123": True, "http://test.example.com/456": True, "https://example.com/789": False, "https://sub.example.com/abc": False, - }) - ] - - def run_accuracy_test(): - print("\nAccuracy Tests:") - print("-" * 50) - - all_passed = True - for patterns, test_urls in test_cases: - filter_obj = URLPatternFilter(patterns) - - for url, expected in test_urls.items(): - result = filter_obj.apply(url) - if result != expected: - print(f"❌ Failed: Pattern '{patterns}' with URL '{url}'") - print(f" Expected: {expected}, Got: {result}") - all_passed = False - else: - print(f"✅ Passed: Pattern '{patterns}' with URL '{url}'") - - return all_passed - - # Run tests - print("Running Pattern Filter Tests...") - accuracy_passed = run_accuracy_test() - - if accuracy_passed: - print("\n✨ All accuracy tests passed!") - - else: - print("\n❌ Some accuracy tests failed!") - -async def test_domain_filter(): - from itertools import chain - - # Test cases - test_cases = [ - # Allowed domains - ({"allowed": "example.com"}, { + }, + ) + return params + +@pytest.mark.asyncio +@pytest.mark.parametrize("pattern,url,expected", pattern_filter_params()) +async def test_pattern_filter(pattern: Union[str, Pattern, List[Union[str, Pattern]]], url: str, expected: bool): + filter_obj = URLPatternFilter(pattern) + result = filter_obj.apply(url) + assert result == expected + + +def domain_filter_params() -> list[ParameterSet]: + params: list[ParameterSet] = [] + def build_params( + name: str, + filter: DomainFilter, + test_urls: dict[str, bool], + ): + for url, expected in test_urls.items(): + params.append(pytest.param(filter, url, expected, id=f"{name} {url}")) + + # Allowed domains + build_params( + "allowed-domains", + DomainFilter(allowed_domains="example.com"), { "https://example.com/page": True, "http://example.com": True, - "https://sub.example.com": False, + "https://sub.example.com": True, "https://other.com": False, - }), + } + ) - ({"allowed": ["example.com", "test.com"]}, { + build_params( + "allowed-domains-list", + DomainFilter(allowed_domains=["example.com", "test.com"]), { "https://example.com/page": True, "https://test.com/home": True, "https://other.com": False, - }), + } + ) - # Blocked domains - ({"blocked": "malicious.com"}, { + # Blocked domains + build_params( + "blocked-domains", + DomainFilter(blocked_domains="malicious.com"), { "https://malicious.com": 
False, "https://safe.com": True, "http://malicious.com/login": False, - }), + } + ) - ({"blocked": ["spam.com", "ads.com"]}, { + build_params( + "blocked-domains-list", + DomainFilter(blocked_domains=["spam.com", "ads.com"]), { "https://spam.com": False, "https://ads.com/banner": False, "https://example.com": True, - }), + } + ) - # Allowed and Blocked combination - ({"allowed": "example.com", "blocked": "sub.example.com"}, { + # Allowed and Blocked combination + build_params( + "allowed-and-blocked", + DomainFilter( + allowed_domains="example.com", + blocked_domains="sub.example.com" + ), { "https://example.com": True, "https://sub.example.com": False, "https://other.com": False, - }), - ] - - def run_accuracy_test(): - print("\nAccuracy Tests:") - print("-" * 50) - - all_passed = True - for params, test_urls in test_cases: - filter_obj = DomainFilter( - allowed_domains=params.get("allowed"), - blocked_domains=params.get("blocked"), - ) - - for url, expected in test_urls.items(): - result = filter_obj.apply(url) - if result != expected: - print(f"\u274C Failed: Params {params} with URL '{url}'") - print(f" Expected: {expected}, Got: {result}") - all_passed = False - else: - print(f"\u2705 Passed: Params {params} with URL '{url}'") - - return all_passed - - # Run tests - print("Running Domain Filter Tests...") - accuracy_passed = run_accuracy_test() - - if accuracy_passed: - print("\n\u2728 All accuracy tests passed!") - else: - print("\n\u274C Some accuracy tests failed!") - -async def test_content_relevance_filter(): + } + ) + + return params + +@pytest.mark.asyncio +@pytest.mark.parametrize("filter,url,expected", domain_filter_params()) +async def test_domain_filter(filter: DomainFilter, url: str, expected: bool): + result = filter.apply(url) + assert result == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("url,expected", [ + ("https://en.wikipedia.org/wiki/Cricket", False), + ("https://en.wikipedia.org/wiki/American_Civil_War", True), +]) +async def test_content_relevance_filter(url: str, expected: bool): relevance_filter = ContentRelevanceFilter( - query="What was the cause of american civil war?", - threshold=1 + query="What was the cause of american civil war?", threshold=1 ) - test_cases = { - "https://en.wikipedia.org/wiki/Cricket": False, - "https://en.wikipedia.org/wiki/American_Civil_War": True, - } - - print("\nRunning Content Relevance Filter Tests...") - print("-" * 50) - - all_passed = True - for url, expected in test_cases.items(): - result = await relevance_filter.apply(url) - if result != expected: - print(f"\u274C Failed: URL '{url}'") - print(f" Expected: {expected}, Got: {result}") - all_passed = False - else: - print(f"\u2705 Passed: URL '{url}'") - - if all_passed: - print("\n\u2728 All content relevance tests passed!") - else: - print("\n\u274C Some content relevance tests failed!") - -async def test_content_type_filter(): - from itertools import chain - - # Test cases - test_cases = [ - # Allowed single type - ({"allowed": "image/png"}, { + result = await relevance_filter.apply(url) + assert result == expected + + +def content_type_filter_params() -> list[ParameterSet]: + params: list[ParameterSet] = [] + def build_params( + name: str, + filter: ContentTypeFilter, + test_urls: dict[str, bool], + ): + for url, expected in test_urls.items(): + params.append(pytest.param(filter, url, expected, id=f"{name} {url}")) + + # Allowed single type + build_params( + "content-type-filter", + ContentTypeFilter(allowed_types="image/png"), { 
"https://example.com/image.png": True, "https://example.com/photo.jpg": False, "https://example.com/document.pdf": False, - }), + }, + ) - # Multiple allowed types - ({"allowed": ["image/jpeg", "application/pdf"]}, { + # Multiple allowed types + build_params( + "multiple-content-types", + ContentTypeFilter(allowed_types=["image/jpeg", "application/pdf"]), { "https://example.com/photo.jpg": True, "https://example.com/document.pdf": True, "https://example.com/script.js": False, - }), + } + ) - # No extension should be allowed - ({"allowed": "application/json"}, { + # No extension should be allowed + build_params( + "no-extension-allowed", + ContentTypeFilter(allowed_types="application/json"), { "https://example.com/api/data": True, "https://example.com/data.json": True, "https://example.com/page.html": False, - }), + } + ) - # Unknown extensions should not be allowed - ({"allowed": "application/octet-stream"}, { + # Unknown extensions should not be allowed + build_params( + "unknown-extension-not-allowed", + ContentTypeFilter(allowed_types="application/octet-stream"), { "https://example.com/file.unknown": True, "https://example.com/archive.zip": False, "https://example.com/software.exe": False, - }), - ] - - def run_accuracy_test(): - print("\nAccuracy Tests:") - print("-" * 50) - - all_passed = True - for params, test_urls in test_cases: - filter_obj = ContentTypeFilter( - allowed_types=params.get("allowed"), - ) - - for url, expected in test_urls.items(): - result = filter_obj.apply(url) - if result != expected: - print(f"\u274C Failed: Params {params} with URL '{url}'") - print(f" Expected: {expected}, Got: {result}") - all_passed = False - else: - print(f"\u2705 Passed: Params {params} with URL '{url}'") - - return all_passed - - # Run tests - print("Running Content Type Filter Tests...") - accuracy_passed = run_accuracy_test() - - if accuracy_passed: - print("\n\u2728 All accuracy tests passed!") - else: - print("\n\u274C Some accuracy tests failed!") - -async def test_seo_filter(): - seo_filter = SEOFilter(threshold=0.5, keywords=["SEO", "search engines", "Optimization"]) - - test_cases = { - "https://en.wikipedia.org/wiki/Search_engine_optimization": True, - "https://en.wikipedia.org/wiki/Randomness": False, - } - - print("\nRunning SEO Filter Tests...") - print("-" * 50) - - all_passed = True - for url, expected in test_cases.items(): - result = await seo_filter.apply(url) - if result != expected: - print(f"\u274C Failed: URL '{url}'") - print(f" Expected: {expected}, Got: {result}") - all_passed = False - else: - print(f"\u2705 Passed: URL '{url}'") - - if all_passed: - print("\n\u2728 All SEO filter tests passed!") - else: - print("\n\u274C Some SEO filter tests failed!") - -import asyncio + } + ) + + return params + +@pytest.mark.asyncio +@pytest.mark.parametrize("filter,url,expected", content_type_filter_params()) +async def test_content_type_filter(filter: ContentTypeFilter, url:str, expected: bool): + result = filter.apply(url) + assert result == expected, f"URL: {url}, Expected: {expected}, Got: {result}" + +@pytest.mark.asyncio +@pytest.mark.parametrize("url,expected", [ + ("https://en.wikipedia.org/wiki/Search_engine_optimization", True), + ("https://en.wikipedia.org/wiki/Randomness", False), +]) +async def test_seo_filter(url: str, expected: bool): + seo_filter = SEOFilter( + threshold=0.5, keywords=["SEO", "search engines", "Optimization"] + ) + result = await seo_filter.apply(url) + assert result == expected if __name__ == "__main__": - asyncio.run(test_pattern_filter()) 
- asyncio.run(test_domain_filter()) - asyncio.run(test_content_type_filter()) - asyncio.run(test_content_relevance_filter()) - asyncio.run(test_seo_filter()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_deep_crawl_scorers.py b/tests/20241401/test_deep_crawl_scorers.py index 8e68bca6b..d3bffc6ed 100644 --- a/tests/20241401/test_deep_crawl_scorers.py +++ b/tests/20241401/test_deep_crawl_scorers.py @@ -1,179 +1,129 @@ -from crawl4ai.deep_crawling.scorers import CompositeScorer, ContentTypeScorer, DomainAuthorityScorer, FreshnessScorer, KeywordRelevanceScorer, PathDepthScorer +import sys +import pytest +from _pytest.mark.structures import ParameterSet -def test_scorers(): - test_cases = [ - # Keyword Scorer Tests +from crawl4ai.deep_crawling.scorers import ( + CompositeScorer, + ContentTypeScorer, + DomainAuthorityScorer, + FreshnessScorer, + KeywordRelevanceScorer, + PathDepthScorer, + URLScorer, +) + + +def scorers_params() -> list[ParameterSet]: + tests: list[ParameterSet] = [] + + def add_tests(name: str, scorer, urls: dict[str, float]): + for url, expected in urls.items(): + tests.append(pytest.param(scorer, url, expected, id=f"{name} {url}")) + + # Keyword Scorer Tests + add_tests( + "keyword-scorer", + KeywordRelevanceScorer(keywords=["python", "blog"], weight=1.0, case_sensitive=False), { - "scorer_type": "keyword", - "config": { - "keywords": ["python", "blog"], - "weight": 1.0, - "case_sensitive": False - }, - "urls": { - "https://example.com/python-blog": 1.0, - "https://example.com/PYTHON-BLOG": 1.0, - "https://example.com/python-only": 0.5, - "https://example.com/other": 0.0 - } + "https://example.com/python-blog": 1.0, + "https://example.com/PYTHON-BLOG": 1.0, + "https://example.com/python-only": 0.5, + "https://example.com/other": 0.0, }, - - # Path Depth Scorer Tests + ) + + # Path Depth Scorer Tests + add_tests( + "path-depth-scorer", + PathDepthScorer(optimal_depth=2, weight=1.0), { - "scorer_type": "path_depth", - "config": { - "optimal_depth": 2, - "weight": 1.0 - }, - "urls": { - "https://example.com/a/b": 1.0, - "https://example.com/a": 0.5, - "https://example.com/a/b/c": 0.5, - "https://example.com": 0.33333333 - } + "https://example.com/a/b": 1.0, + "https://example.com/a": 0.5, + "https://example.com/a/b/c": 0.5, + "https://example.com": 0.33333333, }, - - # Content Type Scorer Tests + ) + + # Content Type Scorer Tests + add_tests( + "content-type-scorer", + ContentTypeScorer(type_weights={".html$": 1.0, ".pdf$": 0.8, ".jpg$": 0.6}, weight=1.0), { - "scorer_type": "content_type", - "config": { - "type_weights": { - ".html$": 1.0, - ".pdf$": 0.8, - ".jpg$": 0.6 - }, - "weight": 1.0 - }, - "urls": { - "https://example.com/doc.html": 1.0, - "https://example.com/doc.pdf": 0.8, - "https://example.com/img.jpg": 0.6, - "https://example.com/other.txt": 0.0 - } + "https://example.com/doc.html": 1.0, + "https://example.com/doc.pdf": 0.8, + "https://example.com/img.jpg": 0.6, + "https://example.com/other.txt": 0.0, }, - - # Freshness Scorer Tests + ) + + # Freshness Scorer Tests + add_tests( + "freshness-scorer", + FreshnessScorer(weight=1.0, current_year=2024), { - "scorer_type": "freshness", - "config": { - "weight": 1.0, # Remove current_year since original doesn't support it - }, - "urls": { - "https://example.com/2024/01/post": 1.0, - "https://example.com/2023/12/post": 0.9, - "https://example.com/2022/post": 0.8, - "https://example.com/no-date": 0.5 - } + 
"https://example.com/2024/01/post": 1.0, + "https://example.com/2023/12/post": 0.9, + "https://example.com/2022/post": 0.8, + "https://example.com/no-date": 0.5, }, - - # Domain Authority Scorer Tests + ) + + # Domain Authority Scorer Tests + add_tests( + "domain-authority-scorer", + DomainAuthorityScorer( + domain_weights={"python.org": 1.0, "github.com": 0.8, "medium.com": 0.6}, default_weight=0.3, weight=1.0 + ), { - "scorer_type": "domain", - "config": { - "domain_weights": { - "python.org": 1.0, - "github.com": 0.8, - "medium.com": 0.6 - }, - "default_weight": 0.3, - "weight": 1.0 - }, - "urls": { - "https://python.org/about": 1.0, - "https://github.com/repo": 0.8, - "https://medium.com/post": 0.6, - "https://unknown.com": 0.3 - } - } - ] - - def create_scorer(scorer_type, config): - if scorer_type == "keyword": - return KeywordRelevanceScorer(**config) - elif scorer_type == "path_depth": - return PathDepthScorer(**config) - elif scorer_type == "content_type": - return ContentTypeScorer(**config) - elif scorer_type == "freshness": - return FreshnessScorer(**config,current_year=2024) - elif scorer_type == "domain": - return DomainAuthorityScorer(**config) - - def run_accuracy_test(): - print("\nAccuracy Tests:") - print("-" * 50) - - all_passed = True - for test_case in test_cases: - print(f"\nTesting {test_case['scorer_type']} scorer:") - scorer = create_scorer( - test_case['scorer_type'], - test_case['config'] - ) - - for url, expected in test_case['urls'].items(): - score = round(scorer.score(url), 8) - expected = round(expected, 8) - - if abs(score - expected) > 0.00001: - print(f"❌ Scorer Failed: URL '{url}'") - print(f" Expected: {expected}, Got: {score}") - all_passed = False - else: - print(f"✅ Scorer Passed: URL '{url}'") - - - return all_passed - - def run_composite_test(): - print("\nTesting Composite Scorer:") - print("-" * 50) - - # Create test data - test_urls = { - "https://python.org/blog/2024/01/new-release.html":0.86666667, - "https://github.com/repo/old-code.pdf": 0.62, - "https://unknown.com/random": 0.26 - } - - # Create composite scorers with all types - scorers = [] - - for test_case in test_cases: - scorer = create_scorer( - test_case['scorer_type'], - test_case['config'] - ) - scorers.append(scorer) - - composite = CompositeScorer(scorers, normalize=True) - - all_passed = True - for url, expected in test_urls.items(): - score = round(composite.score(url), 8) - - if abs(score - expected) > 0.00001: - print(f"❌ Composite Failed: URL '{url}'") - print(f" Expected: {expected}, Got: {score}") - all_passed = False - else: - print(f"✅ Composite Passed: URL '{url}'") - - return all_passed - - # Run tests - print("Running Scorer Tests...") - accuracy_passed = run_accuracy_test() - composite_passed = run_composite_test() - - if accuracy_passed and composite_passed: - print("\n✨ All tests passed!") - # Note: Already have performance tests in run_scorer_performance_test() - else: - print("\n❌ Some tests failed!") - - + "https://python.org/about": 1.0, + "https://github.com/repo": 0.8, + "https://medium.com/post": 0.6, + "https://unknown.com": 0.3, + }, + ) + + return tests + + +@pytest.mark.parametrize("scorer,url,expected", scorers_params()) +def test_accuracy(scorer: URLScorer, url: str, expected: float): + score = round(scorer.score(url), 8) + expected = round(expected, 8) + + assert abs(score - expected) < 0.00001, f"Expected: {expected}, Got: {score}" + + +def composite_scorer_params() -> list[ParameterSet]: + composite = CompositeScorer( + [ + 
KeywordRelevanceScorer(keywords=["python", "blog"], weight=1.0, case_sensitive=False), + PathDepthScorer(optimal_depth=2, weight=1.0), + ContentTypeScorer(type_weights={".html$": 1.0, ".pdf$": 0.8, ".jpg$": 0.6}, weight=1.0), + FreshnessScorer(weight=1.0, current_year=2024), + DomainAuthorityScorer( + domain_weights={"python.org": 1.0, "github.com": 0.8, "medium.com": 0.6}, default_weight=0.3, weight=1.0 + ), + ], + normalize=True, + ) + + test_urls = { + "https://python.org/blog/2024/01/new-release.html": 0.86666667, + "https://github.com/repo/old-code.pdf": 0.62, + "https://unknown.com/random": 0.26, + } + + return [pytest.param(url, expected, composite, id=url) for url, expected in test_urls.items()] + + +@pytest.mark.parametrize("url,expected,composite", composite_scorer_params()) +def test_composite(url: str, expected: float, composite: CompositeScorer): + score = round(composite.score(url), 8) + assert abs(score - expected) < 0.00001, f"Expected: {expected}, Got: {score}" + if __name__ == "__main__": - test_scorers() \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_http_crawler_strategy.py b/tests/20241401/test_http_crawler_strategy.py index dc1414188..175f29aca 100644 --- a/tests/20241401/test_http_crawler_strategy.py +++ b/tests/20241401/test_http_crawler_strategy.py @@ -1,116 +1,131 @@ -from tkinter import N -from crawl4ai.async_crawler_strategy import AsyncHTTPCrawlerStrategy +import sys +from crawl4ai.async_crawler_strategy import AsyncHTTPCrawlerStrategy, ConnectionTimeoutError from crawl4ai.async_logger import AsyncLogger from crawl4ai import CrawlerRunConfig, HTTPCrawlerConfig -from crawl4ai.async_crawler_strategy import ConnectionTimeoutError -import asyncio -import os +import pytest +import pytest_asyncio +from httpx import codes -async def main(): - """Test the AsyncHTTPCrawlerStrategy with various scenarios""" - logger = AsyncLogger(verbose=True) - - # Initialize the strategy with default HTTPCrawlerConfig - crawler = AsyncHTTPCrawlerStrategy( +@pytest_asyncio.fixture +async def crawler(): + async with AsyncHTTPCrawlerStrategy( browser_config=HTTPCrawlerConfig(), - logger=logger - ) - # Test 1: Basic HTTP GET - print("\n=== Test 1: Basic HTTP GET ===") + logger=AsyncLogger(verbose=True) + ) as crawler: + yield crawler + +@pytest.mark.asyncio +async def test_basic_get(crawler: AsyncHTTPCrawlerStrategy): result = await crawler.crawl("https://example.com") + assert result.status_code == codes.OK + assert result.html + assert result.response_headers print(f"Status: {result.status_code}") print(f"Content length: {len(result.html)}") print(f"Headers: {dict(result.response_headers)}") - # Test 2: POST request with JSON - print("\n=== Test 2: POST with JSON ===") +@pytest.mark.asyncio +async def test_post_with_json(crawler: AsyncHTTPCrawlerStrategy): crawler.browser_config = crawler.browser_config.clone( method="POST", json={"test": "data"}, headers={"Content-Type": "application/json"} ) - try: - result = await crawler.crawl( - "https://httpbin.org/post", - ) - print(f"Status: {result.status_code}") - print(f"Response: {result.html[:200]}...") - except Exception as e: - print(f"Error: {e}") + result = await crawler.crawl( + "https://httpbin.org/post", + ) + assert result.status_code == codes.OK + assert result.html + print(f"Response: {result.html[:200]}...") - # Test 3: File handling +@pytest.mark.asyncio +async def test_file_handling(crawler: AsyncHTTPCrawlerStrategy): 
crawler.browser_config = HTTPCrawlerConfig() - print("\n=== Test 3: Local file handling ===") # Create a tmp file with test content from tempfile import NamedTemporaryFile with NamedTemporaryFile(delete=False) as f: f.write(b"Test content") f.close() result = await crawler.crawl(f"file://{f.name}") - print(f"File content: {result.html}") + assert result.status_code == codes.OK + assert result.html == "Test content" - # Test 4: Raw content - print("\n=== Test 4: Raw content handling ===") +@pytest.mark.asyncio +async def test_raw_content(crawler: AsyncHTTPCrawlerStrategy): raw_html = "raw://Raw test content" result = await crawler.crawl(raw_html) - print(f"Raw content: {result.html}") + assert result.status_code == codes.OK + assert result.html == "Raw test content" - # Test 5: Custom hooks - print("\n=== Test 5: Custom hooks ===") +@pytest.mark.asyncio +async def test_custom_hooks(crawler: AsyncHTTPCrawlerStrategy): + before_called: bool = False async def before_request(url, kwargs): print(f"Before request to {url}") kwargs['headers']['X-Custom'] = 'test' + nonlocal before_called + before_called = True + after_called: bool = False async def after_request(response): print(f"After request, status: {response.status_code}") + nonlocal after_called + after_called = True crawler.set_hook('before_request', before_request) crawler.set_hook('after_request', after_request) result = await crawler.crawl("https://example.com") + assert result.status_code == codes.OK + assert result.html + assert before_called + assert after_called - # Test 6: Error handling - print("\n=== Test 6: Error handling ===") - try: +@pytest.mark.asyncio +async def test_error_handling(crawler: AsyncHTTPCrawlerStrategy): + with pytest.raises(ConnectionError): await crawler.crawl("https://nonexistent.domain.test") - except Exception as e: - print(f"Expected error: {e}") - # Test 7: Redirects - print("\n=== Test 7: Redirect handling ===") +@pytest.mark.asyncio +async def test_redirects(crawler: AsyncHTTPCrawlerStrategy): crawler.browser_config = HTTPCrawlerConfig(follow_redirects=True) result = await crawler.crawl("http://httpbin.org/redirect/1") - print(f"Final URL: {result.redirected_url}") + assert result.status_code == codes.OK + assert result.redirected_url == "http://httpbin.org/get" - # Test 8: Custom timeout +@pytest.mark.asyncio +async def test_custom_timeout(crawler: AsyncHTTPCrawlerStrategy): print("\n=== Test 8: Custom timeout ===") - try: + with pytest.raises(ConnectionTimeoutError): await crawler.crawl( "https://httpbin.org/delay/5", config=CrawlerRunConfig(page_timeout=2) ) - except ConnectionTimeoutError as e: - print(f"Expected timeout: {e}") - # Test 9: SSL verification - print("\n=== Test 9: SSL verification ===") +@pytest.mark.asyncio +async def test_ssl_verify_off(crawler: AsyncHTTPCrawlerStrategy): crawler.browser_config = HTTPCrawlerConfig(verify_ssl=False) - try: + result = await crawler.crawl("https://expired.badssl.com/") + assert result.status_code == codes.OK + +@pytest.mark.asyncio +async def test_ssl_verify_on(crawler: AsyncHTTPCrawlerStrategy): + with pytest.raises(ConnectionError): + crawler.browser_config = HTTPCrawlerConfig() await crawler.crawl("https://expired.badssl.com/") - print("Connected to invalid SSL site with verification disabled") - except Exception as e: - print(f"SSL error: {e}") - # Test 10: Large file streaming - print("\n=== Test 10: Large file streaming ===") +@pytest.mark.asyncio +async def test_large_file_streaming(crawler: AsyncHTTPCrawlerStrategy): from tempfile import 
NamedTemporaryFile - with NamedTemporaryFile(delete=False) as f: + with NamedTemporaryFile() as f: f.write(b"" + b"X" * 1024 * 1024 * 10 + b"") - f.close() + f.flush() + size: int = f.tell() result = await crawler.crawl("file://" + f.name) - print(f"Large file content length: {len(result.html)}") - os.remove(f.name) - - crawler.close() + assert result.status_code == codes.OK + assert len(result.html) == size + f.close() if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_llm_filter.py b/tests/20241401/test_llm_filter.py index 6211c4295..b82695760 100644 --- a/tests/20241401/test_llm_filter.py +++ b/tests/20241401/test_llm_filter.py @@ -1,10 +1,16 @@ import os -import asyncio -from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode -from crawl4ai import LLMConfig +import sys +import pytest from crawl4ai.content_filter_strategy import LLMContentFilter +from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode +from crawl4ai.async_configs import LLMConfig +@pytest.mark.asyncio +@pytest.mark.timeout(200) async def test_llm_filter(): + if not os.getenv("OPENAI_API_KEY"): + pytest.skip("Skipping env OPENAI_API_KEY not set") + # Create an HTML source that needs intelligent filtering url = "https://docs.python.org/3/tutorial/classes.html" @@ -12,14 +18,15 @@ async def test_llm_filter(): headless=True, verbose=True ) - - # run_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS) + run_config = CrawlerRunConfig(cache_mode=CacheMode.ENABLED) async with AsyncWebCrawler(config=browser_config) as crawler: # First get the raw HTML result = await crawler.arun(url, config=run_config) + assert result.success html = result.cleaned_html + assert html # Initialize LLM filter with focused instruction filter = LLMContentFilter( @@ -59,28 +66,32 @@ async def test_llm_filter(): - Sidebars with external links - Any UI elements that don't contribute to learning - The goal is to create a clean markdown version that reads exactly like the original article, - keeping all valuable content but free from distracting elements. Imagine you're creating + The goal is to create a clean markdown version that reads exactly like the original article, + keeping all valuable content but free from distracting elements. Imagine you're creating a perfect reading experience where nothing valuable is lost, but all noise is removed. 
""", - verbose=True - ) + verbose=True, + ) # Apply filtering filtered_content = filter.filter_content(html, ignore_cache = True) - + assert filtered_content + # Show results print("\nFiltered Content Length:", len(filtered_content)) print("\nFirst 500 chars of filtered content:") if filtered_content: print(filtered_content[0][:500]) - + # Save on disc the markdown version with open("filtered_content.md", "w", encoding="utf-8") as f: f.write("\n".join(filtered_content)) - + # Show token usage filter.show_usage() + if __name__ == "__main__": - asyncio.run(test_llm_filter()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_robot.py b/tests/20241401/test_robot.py new file mode 100644 index 000000000..081990e7e --- /dev/null +++ b/tests/20241401/test_robot.py @@ -0,0 +1,61 @@ +import sys + +import pytest +from _pytest.mark.structures import ParameterSet + +from crawl4ai import BrowserConfig, CacheMode, CrawlerRunConfig +from crawl4ai.async_webcrawler import AsyncWebCrawler + +TEST_CASES = [ + # Public sites that should be allowed + ("https://example.com", True), # Simple public site + ("https://httpbin.org/get", True), # API endpoint + # Sites with known strict robots.txt + ("https://www.facebook.com/robots.txt", False), # Social media + ("https://www.google.com/search", False), # Search pages + # Edge cases + ("https://api.github.com", True), # API service + ("https://raw.githubusercontent.com", True), # Content delivery + # Non-existent/error cases + ("https://thisisnotarealwebsite123.com", False), # Non-existent domain + ("https://localhost:12345", False), # Invalid port +] + + +def website_params() -> list[ParameterSet]: + return [pytest.param(url, expected, id=url) for url, expected in TEST_CASES] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("url, expected", website_params()) +async def test_real_websites(url: str, expected: bool): + print("\n=== Testing Real Website Robots.txt Compliance ===\n") + + browser_config = BrowserConfig(headless=True, verbose=True) + async with AsyncWebCrawler(config=browser_config) as crawler: + config = CrawlerRunConfig( + cache_mode=CacheMode.BYPASS, + check_robots_txt=True, # Enable robots.txt checking + verbose=True, + ) + + result = await crawler.arun(url=url, config=config) + allowed = result.success and not result.error_message + + print(f"Expected: {'allowed' if expected else 'denied'}") + print(f"Actual: {'allowed' if allowed else 'denied'}") + print(f"Status Code: {result.status_code}") + if result.error_message: + print(f"Error: {result.error_message}") + + assert expected == allowed, f"Expected {expected} but got {allowed} for {url}" + + # Optional: Print robots.txt content if available + if result.metadata and 'robots_txt' in result.metadata: + print(f"Robots.txt rules:\n{result.metadata['robots_txt']}") + + +if __name__ == "__main__": + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_robot_parser.py b/tests/20241401/test_robot_parser.py index a2fc30f1a..8fa793491 100644 --- a/tests/20241401/test_robot_parser.py +++ b/tests/20241401/test_robot_parser.py @@ -1,159 +1,163 @@ -from crawl4ai.utils import RobotsParser - import asyncio -import aiohttp +import os +from pathlib import Path +import sys +import time + +import pytest from aiohttp import web -import tempfile -import shutil -import os, sys, time, json +from crawl4ai.utils import RobotsParser -async def 
test_robots_parser(): + +@pytest.mark.asyncio +async def test_robots_parser(tmp_path: Path ): print("\n=== Testing RobotsParser ===\n") - - # Setup temporary directory for testing - temp_dir = tempfile.mkdtemp() + temp_dir = tmp_path.as_posix() + # 1. Basic setup test + print("1. Testing basic initialization...") + parser = RobotsParser(cache_dir=temp_dir) + assert os.path.exists(parser.db_path), "Database file not created" + print("✓ Basic initialization passed") + + # 2. Test common cases + print("\n2. Testing common cases...") + start = time.time() + allowed = await parser.can_fetch("https://httpbin.org", "MyBot/1.0") + uncached_duration: float = time.time() - start + assert allowed + print(f"✓ Regular website fetch: {'allowed' if allowed else 'denied'} took: {uncached_duration * 1000:.2f}ms") + + # Test caching + print("Testing cache...") + start = time.time() + await parser.can_fetch("https://httpbin.org", "MyBot/1.0") + duration = time.time() - start + print(f"✓ Cached lookup took: {duration * 1000:.2f}ms") + # Using a hardcoded threshold results in flaky tests so + # we just check that the cached lookup is faster than the uncached one. + assert duration < uncached_duration, "Cache lookup too slow" # + + # 3. Edge cases + print("\n3. Testing edge cases...") + + # Empty URL + result = await parser.can_fetch("", "MyBot/1.0") + assert result + print(f"✓ Empty URL handled: {'allowed' if result else 'denied'}") + + # Invalid URL + result = await parser.can_fetch("not_a_url", "MyBot/1.0") + assert result + print(f"✓ Invalid URL handled: {'allowed' if result else 'denied'}") + + # URL without scheme + result = await parser.can_fetch("example.com/page", "MyBot/1.0") + assert result + print(f"✓ URL without scheme handled: {'allowed' if result else 'denied'}") + + # 4. Test with local server + async def start_test_server(): + app = web.Application() + + async def robots_txt(request): + return web.Response( + text="""User-agent: * +Disallow: /private/ +Allow: /public/ +""" + ) + + async def malformed_robots(request): + return web.Response(text="<<>>") + + async def timeout_robots(request): + await asyncio.sleep(5) + return web.Response(text="Should timeout") + + async def empty_robots(request): + return web.Response(text="") + + async def giant_robots(request): + return web.Response(text="User-agent: *\nDisallow: /\n" * 10000) + + # Mount all handlers at root level + app.router.add_get("/robots.txt", robots_txt) + app.router.add_get("/malformed/robots.txt", malformed_robots) + app.router.add_get("/timeout/robots.txt", timeout_robots) + app.router.add_get("/empty/robots.txt", empty_robots) + app.router.add_get("/giant/robots.txt", giant_robots) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "localhost", 0) + await site.start() + return runner + + runner = await start_test_server() try: - # 1. Basic setup test - print("1. Testing basic initialization...") - parser = RobotsParser(cache_dir=temp_dir) - assert os.path.exists(parser.db_path), "Database file not created" - print("✓ Basic initialization passed") - - # 2. Test common cases - print("\n2. Testing common cases...") - allowed = await parser.can_fetch("https://www.example.com", "MyBot/1.0") - print(f"✓ Regular website fetch: {'allowed' if allowed else 'denied'}") - - # Test caching - print("Testing cache...") + print("\n4. Testing robots.txt rules...") + # Addresses are either IPv4 or IPv6, in both types the port is the second element. 
+ port: int = runner.addresses[0][1] + base_url = f"http://localhost:{port}" + + # Test public access + result = await parser.can_fetch(f"{base_url}/public/page", "bot") + print(f"Public access (/public/page): {'allowed' if result else 'denied'}") + assert result, "Public path should be allowed" + + # Test private access + result = await parser.can_fetch(f"{base_url}/private/secret", "bot") + assert not result, "Private path should be denied" + + # Test malformed + result = await parser.can_fetch( + f"{base_url}/malformed/page", "bot" + ) + assert result, "Malformed robots.txt should be handled as allowed" + + # Test timeout start = time.time() - await parser.can_fetch("https://www.example.com", "MyBot/1.0") + result = await parser.can_fetch(f"{base_url}/timeout/page", "bot") duration = time.time() - start - print(f"✓ Cached lookup took: {duration*1000:.2f}ms") - assert duration < 0.03, "Cache lookup too slow" - - # 3. Edge cases - print("\n3. Testing edge cases...") - - # Empty URL - result = await parser.can_fetch("", "MyBot/1.0") - print(f"✓ Empty URL handled: {'allowed' if result else 'denied'}") - - # Invalid URL - result = await parser.can_fetch("not_a_url", "MyBot/1.0") - print(f"✓ Invalid URL handled: {'allowed' if result else 'denied'}") - - # URL without scheme - result = await parser.can_fetch("example.com/page", "MyBot/1.0") - print(f"✓ URL without scheme handled: {'allowed' if result else 'denied'}") - - # 4. Test with local server - async def start_test_server(): - app = web.Application() - - async def robots_txt(request): - return web.Response(text="""User-agent: * -Disallow: /private/ -Allow: /public/ -""") - - async def malformed_robots(request): - return web.Response(text="<<>>") - - async def timeout_robots(request): - await asyncio.sleep(5) - return web.Response(text="Should timeout") - - async def empty_robots(request): - return web.Response(text="") - - async def giant_robots(request): - return web.Response(text="User-agent: *\nDisallow: /\n" * 10000) - - # Mount all handlers at root level - app.router.add_get('/robots.txt', robots_txt) - app.router.add_get('/malformed/robots.txt', malformed_robots) - app.router.add_get('/timeout/robots.txt', timeout_robots) - app.router.add_get('/empty/robots.txt', empty_robots) - app.router.add_get('/giant/robots.txt', giant_robots) - - runner = web.AppRunner(app) - await runner.setup() - site = web.TCPSite(runner, 'localhost', 8080) - await site.start() - return runner - - runner = await start_test_server() - try: - print("\n4. 
Testing robots.txt rules...") - base_url = "http://localhost:8080" - - # Test public access - result = await parser.can_fetch(f"{base_url}/public/page", "bot") - print(f"Public access (/public/page): {'allowed' if result else 'denied'}") - assert result, "Public path should be allowed" - - # Test private access - result = await parser.can_fetch(f"{base_url}/private/secret", "bot") - print(f"Private access (/private/secret): {'allowed' if result else 'denied'}") - assert not result, "Private path should be denied" - - # Test malformed - result = await parser.can_fetch("http://localhost:8080/malformed/page", "bot") - print(f"✓ Malformed robots.txt handled: {'allowed' if result else 'denied'}") - - # Test timeout - start = time.time() - result = await parser.can_fetch("http://localhost:8080/timeout/page", "bot") - duration = time.time() - start - print(f"✓ Timeout handled (took {duration:.2f}s): {'allowed' if result else 'denied'}") - assert duration < 3, "Timeout not working" - - # Test empty - result = await parser.can_fetch("http://localhost:8080/empty/page", "bot") - print(f"✓ Empty robots.txt handled: {'allowed' if result else 'denied'}") - - # Test giant file - start = time.time() - result = await parser.can_fetch("http://localhost:8080/giant/page", "bot") - duration = time.time() - start - print(f"✓ Giant robots.txt handled (took {duration:.2f}s): {'allowed' if result else 'denied'}") - - finally: - await runner.cleanup() - - # 5. Cache manipulation - print("\n5. Testing cache manipulation...") - - # Clear expired - parser.clear_expired() - print("✓ Clear expired entries completed") - - # Clear all - parser.clear_cache() - print("✓ Clear all cache completed") - - # Test with custom TTL - custom_parser = RobotsParser(cache_dir=temp_dir, cache_ttl=1) # 1 second TTL - await custom_parser.can_fetch("https://www.example.com", "bot") - print("✓ Custom TTL fetch completed") - await asyncio.sleep(1.1) + assert result, "Timeout should be handled as allowed" + assert duration < 3, "Timeout not working" + + # Test empty + result = await parser.can_fetch(f"{base_url}/empty/page", "bot") + assert result, "Empty robots.txt should be handled as allowed" + + # Test giant file start = time.time() - await custom_parser.can_fetch("https://www.example.com", "bot") - print(f"✓ TTL expiry working (refetched after {time.time() - start:.2f}s)") + result = await parser.can_fetch(f"{base_url}/giant/page", "bot") + assert result, "Giant robots.txt should be handled as allowed" finally: - # Cleanup - shutil.rmtree(temp_dir) - print("\nTest cleanup completed") + await runner.cleanup() + + # 5. Cache manipulation + print("\n5. 
Testing cache manipulation...") + + # Clear expired + parser.clear_expired() + print("✓ Clear expired entries completed") + + # Clear all + parser.clear_cache() + print("✓ Clear all cache completed") + + # Test with custom TTL + custom_parser = RobotsParser(cache_dir=temp_dir, cache_ttl=1) # 1 second TTL + result = await custom_parser.can_fetch("https://www.example.com", "bot") + assert result, "Custom TTL fetch failed" + + await asyncio.sleep(1.1) + start = time.time() + result = await custom_parser.can_fetch("https://www.example.com", "bot") + assert result, "Custom TTL fetch failed after expiry" -async def main(): - try: - await test_robots_parser() - except Exception as e: - print(f"Test failed: {str(e)}") - raise if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_schema_builder.py b/tests/20241401/test_schema_builder.py index 46d0e2401..d4b9ee8fd 100644 --- a/tests/20241401/test_schema_builder.py +++ b/tests/20241401/test_schema_builder.py @@ -1,17 +1,14 @@ # https://claude.ai/chat/c4bbe93d-fb54-44ce-92af-76b4c8086c6b # https://claude.ai/chat/c24a768c-d8b2-478a-acc7-d76d42a308da -import os, sys +import json +import os +import sys +from typing import Any, Optional -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) -__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) +import pytest -import asyncio -from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode -from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator -from crawl4ai.extraction_strategy import JsonCssExtractionStrategy, JsonXPathExtractionStrategy -from crawl4ai.utils import preprocess_html_for_schema, JsonXPathExtractionStrategy -import json +from crawl4ai import LLMConfig +from crawl4ai.extraction_strategy import JsonCssExtractionStrategy # Test HTML - A complex job board with companies, departments, and positions test_html = """ @@ -83,30 +80,123 @@ """ -# Test cases -def test_schema_generation(): - # Test 1: No query (should extract everything) - print("\nTest 1: No Query (Full Schema)") - schema1 = JsonCssExtractionStrategy.generate_schema(test_html) - print(json.dumps(schema1, indent=2)) - - # Test 2: Query for just basic job info - print("\nTest 2: Basic Job Info Query") - query2 = "I only need job titles, salaries, and locations" - schema2 = JsonCssExtractionStrategy.generate_schema(test_html, query2) - print(json.dumps(schema2, indent=2)) - - # Test 3: Query for company and department structure - print("\nTest 3: Organizational Structure Query") - query3 = "Extract company details and department names, without position details" - schema3 = JsonCssExtractionStrategy.generate_schema(test_html, query3) - print(json.dumps(schema3, indent=2)) - - # Test 4: Query for specific skills tracking - print("\nTest 4: Skills Analysis Query") - query4 = "I want to analyze required skills across all positions" - schema4 = JsonCssExtractionStrategy.generate_schema(test_html, query4) - print(json.dumps(schema4, indent=2)) + +@pytest.fixture +def config() -> LLMConfig: + """Load OpenAI API key from environment variable. 
+ + If the API key is not found, skip the test.""" + api_token: Optional[str] = os.environ.get("OPENAI_API_KEY") + if not api_token: + pytest.skip("OpenAI API key is required for this test") + + return LLMConfig(api_token=api_token) + + +def test_no_query_full_schema(config: LLMConfig): + """No query (should extract everything)""" + schema = JsonCssExtractionStrategy.generate_schema(test_html, llm_config=config) + assert schema + assert isinstance(schema, dict) + assert schema.get("name", "") + fields: list[dict[str, Any]] = schema.get("fields", []) + assert len(fields) == 6 + seen_positions: bool = False + for field in fields: + assert isinstance(field, dict) + if field.get("name", "") == "departments": + department_fields: list[dict[str, Any]] = field.get("fields", []) + assert len(department_fields) == 2 + for department_field in department_fields: + assert isinstance(department_field, dict) + if department_field.get("name", "") == "positions": + position_fields: list[dict[str, Any]] = department_field.get("fields", []) + assert len(position_fields) > 8 + seen_positions = True + assert seen_positions + + +@pytest.mark.skip(reason="LLM extraction can be unpredictable") +def test_basic_job_info(config: LLMConfig): + """Query for just basic job info""" + query = "I only need job titles, salaries, and locations" + schema = JsonCssExtractionStrategy.generate_schema(test_html, query=query, llm_config=config) + print(json.dumps(schema, indent=2)) + assert schema + assert isinstance(schema, dict) + assert schema.get("name", "") + fields: list[dict[str, Any]] = schema.get("fields", []) + assert len(fields) == 3 + seen_job_title: bool = False + seen_salary_range: bool = False + seen_location: bool = False + for field in fields: + assert isinstance(field, dict) + name: str = field.get("name", "") + if name == "job_title": + seen_job_title = True + elif name == "salary_range": + seen_salary_range = True + elif name == "location": + seen_location = True + + assert seen_job_title + assert seen_salary_range + assert seen_location + + +def test_company_and_department_structure(config: LLMConfig): + """Query for company and department structure""" + query = "Extract company details and department names, without position details" + schema = JsonCssExtractionStrategy.generate_schema(test_html, query=query, llm_config=config) + print(json.dumps(schema, indent=2)) + assert schema + assert isinstance(schema, dict) + assert schema.get("name", "") + fields: list[dict[str, Any]] = schema.get("fields", []) + assert len(fields) == 6 + seen_department_name: bool = False + for field in fields: + assert isinstance(field, dict) + if field.get("name", "") == "departments": + department_fields: list[dict[str, Any]] = field.get("fields", []) + assert len(department_fields) == 1 + for department_field in department_fields: + assert isinstance(department_field, dict) + if department_field.get("name", "") == "department_name": + seen_department_name = True + assert seen_department_name + + +@pytest.mark.skip(reason="LLM extraction can be unpredictable") +def test_specific_skills_tracking(config: LLMConfig): + """Query for specific skills tracking""" + query = "I want to analyze required skills across all positions" + schema = JsonCssExtractionStrategy.generate_schema(test_html, query=query, llm_config=config) + print(json.dumps(schema, indent=2)) + assert schema + assert isinstance(schema, dict) + assert schema.get("name", "") + fields: list[dict[str, Any]] = schema.get("fields", []) + seen_skills_required: bool = False + 
for field in fields: + assert isinstance(field, dict) + if field.get("name", "") == "departments": + department_fields: list[dict[str, Any]] = field.get("fields", []) + assert len(department_fields) == 2 + for department_field in department_fields: + assert isinstance(department_field, dict) + if department_field.get("name", "") == "positions": + position_fields: list[dict[str, Any]] = department_field.get("fields", []) + assert len(position_fields) + for position_field in position_fields: + assert isinstance(position_field, dict) + if position_field.get("name", "") == "skills_required": + seen_skills_required = True + assert seen_skills_required + if __name__ == "__main__": - test_schema_generation() \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_stream.py b/tests/20241401/test_stream.py index 5614eb725..465dd1a42 100644 --- a/tests/20241401/test_stream.py +++ b/tests/20241401/test_stream.py @@ -1,50 +1,42 @@ -import os, sys -# append 2 parent directories to sys.path to import crawl4ai -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) -parent_parent_dir = os.path.dirname(parent_dir) -sys.path.append(parent_parent_dir) +import sys -import asyncio -from crawl4ai import * +import pytest +from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode, CrawlerRunConfig, PruningContentFilter +from crawl4ai.models import CrawlResultContainer +from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator + + +@pytest.mark.asyncio async def test_crawler(): # Setup configurations browser_config = BrowserConfig(headless=True, verbose=False) crawler_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator( - content_filter=PruningContentFilter( - threshold=0.48, - threshold_type="fixed", - min_word_threshold=0 - ) + content_filter=PruningContentFilter(threshold=0.48, threshold_type="fixed", min_word_threshold=0) ), ) # Test URLs - mix of different sites - urls = [ - "http://example.com", - "http://example.org", - "http://example.net", - ] * 10 # 15 total URLs + urls = ["http://example.com", "http://example.org", "http://example.net"] * 10 # 15 total URLs async with AsyncWebCrawler(config=browser_config) as crawler: print("\n=== Testing Streaming Mode ===") - async for result in await crawler.arun_many( - urls=urls, - config=crawler_config.clone(stream=True), - ): + async for result in await crawler.arun_many(urls=urls, config=crawler_config.clone(stream=True)): print(f"Received result for: {result.url} - Success: {result.success}") - + print("\n=== Testing Batch Mode ===") - results = await crawler.arun_many( - urls=urls, - config=crawler_config, - ) + results = await crawler.arun_many(urls=urls, config=crawler_config) + assert isinstance(results, CrawlResultContainer) + assert len(results) == len(urls), "Expected the same number of results as URLs" print(f"Received all {len(results)} results at once") for result in results: + assert result.success print(f"Batch result for: {result.url} - Success: {result.success}") + if __name__ == "__main__": - asyncio.run(test_crawler()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/test_stream_dispatch.py b/tests/20241401/test_stream_dispatch.py index 0b5d004c5..9709d7300 100644 --- a/tests/20241401/test_stream_dispatch.py +++ 
b/tests/20241401/test_stream_dispatch.py @@ -1,39 +1,31 @@ -import os, sys -# append 2 parent directories to sys.path to import crawl4ai -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) -parent_parent_dir = os.path.dirname(parent_dir) -sys.path.append(parent_parent_dir) +import sys +import pytest -import asyncio -from typing import List -from crawl4ai import * +from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode, CrawlerRunConfig from crawl4ai.async_dispatcher import MemoryAdaptiveDispatcher +from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator + +@pytest.mark.asyncio async def test_streaming(): browser_config = BrowserConfig(headless=True, verbose=True) crawler_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, - markdown_generator=DefaultMarkdownGenerator( - # content_filter=PruningContentFilter( - # threshold=0.48, - # threshold_type="fixed", - # min_word_threshold=0 - # ) - ), + markdown_generator=DefaultMarkdownGenerator(), ) urls = ["http://example.com"] * 10 - + async with AsyncWebCrawler(config=browser_config) as crawler: - dispatcher = MemoryAdaptiveDispatcher( - max_session_permit=5, - check_interval=0.5 - ) - + dispatcher = MemoryAdaptiveDispatcher(max_session_permit=5, check_interval=0.5) + async for result in dispatcher.run_urls_stream(urls, crawler, crawler_config): + assert result.result.success print(f"Got result for {result.url} - Success: {result.result.success}") + if __name__ == "__main__": - asyncio.run(test_streaming()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/20241401/tets_robot.py b/tests/20241401/tets_robot.py deleted file mode 100644 index 9bb30bb9e..000000000 --- a/tests/20241401/tets_robot.py +++ /dev/null @@ -1,62 +0,0 @@ -import asyncio -from crawl4ai import * - -async def test_real_websites(): - print("\n=== Testing Real Website Robots.txt Compliance ===\n") - - browser_config = BrowserConfig(headless=True, verbose=True) - async with AsyncWebCrawler(config=browser_config) as crawler: - - # Test cases with URLs - test_cases = [ - # Public sites that should be allowed - ("https://example.com", True), # Simple public site - ("https://httpbin.org/get", True), # API endpoint - - # Sites with known strict robots.txt - ("https://www.facebook.com/robots.txt", False), # Social media - ("https://www.google.com/search", False), # Search pages - - # Edge cases - ("https://api.github.com", True), # API service - ("https://raw.githubusercontent.com", True), # Content delivery - - # Non-existent/error cases - ("https://thisisnotarealwebsite.com", True), # Non-existent domain - ("https://localhost:12345", True), # Invalid port - ] - - for url, expected in test_cases: - print(f"\nTesting: {url}") - try: - config = CrawlerRunConfig( - cache_mode=CacheMode.BYPASS, - check_robots_txt=True, # Enable robots.txt checking - verbose=True - ) - - result = await crawler.arun(url=url, config=config) - allowed = result.success and not result.error_message - - print(f"Expected: {'allowed' if expected else 'denied'}") - print(f"Actual: {'allowed' if allowed else 'denied'}") - print(f"Status Code: {result.status_code}") - if result.error_message: - print(f"Error: {result.error_message}") - - # Optional: Print robots.txt content if available - if result.metadata and 'robots_txt' in result.metadata: - print(f"Robots.txt rules:\n{result.metadata['robots_txt']}") - - except Exception as e: - print(f"Test 
failed with error: {str(e)}") - -async def main(): - try: - await test_real_websites() - except Exception as e: - print(f"Test suite failed: {str(e)}") - raise - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/tests/async/__init__.py b/tests/async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/async/conftest.py b/tests/async/conftest.py new file mode 100644 index 000000000..7be3fb066 --- /dev/null +++ b/tests/async/conftest.py @@ -0,0 +1,10 @@ +import pytest + +from .test_content_scraper_strategy import print_comparison_table, write_results_to_csv + + +@pytest.hookimpl +def pytest_sessionfinish(session: pytest.Session, exitstatus: int): + write_results_to_csv() + if session.config.getoption("verbose"): + print_comparison_table() diff --git a/tests/async/data/wikipedia.html b/tests/async/data/wikipedia.html new file mode 100644 index 000000000..330d441ee --- /dev/null +++ b/tests/async/data/wikipedia.html @@ -0,0 +1,961 @@ + + + + +Wikipedia, the free encyclopedia + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +Jump to content +
+ [tests/async/data/wikipedia.html: 961-line HTML snapshot of the Wikipedia Main Page (site navigation, featured article, "Did you know", "In the news", "On this day", featured picture, sister projects, and language lists), added as a test fixture; raw markup omitted]
+ + + + \ No newline at end of file diff --git a/tests/async/data/wikipedia.md b/tests/async/data/wikipedia.md new file mode 100644 index 000000000..7fa107adc --- /dev/null +++ b/tests/async/data/wikipedia.md @@ -0,0 +1,179 @@ +# Wikipedia, the free encyclopedia +From Wikipedia, the free encyclopedia + +From today's featured article +----------------------------- + +[![Frederick Steele](https://upload.wikimedia.org/wikipedia/commons/7/72/Frederick_Steele_%28cropped%2C_retouched%29.jpg)](/wiki/File:Frederick_Steele_\(cropped,_retouched\).jpg "Frederick Steele") + +Frederick Steele + +**[Steele's Greenville expedition](https://en.wikipedia.org/wiki/Steele%27s_Greenville_expedition "Steele's Greenville expedition")** took place from April 2 to April 25, 1863, during the [Vicksburg campaign](https://en.wikipedia.org/wiki/Vicksburg_campaign "Vicksburg campaign") of the [American Civil War](https://en.wikipedia.org/wiki/American_Civil_War "American Civil War"). [Union](https://en.wikipedia.org/wiki/Union_\(American_Civil_War\) "Union (American Civil War)") forces commanded by Major General [Frederick Steele](https://en.wikipedia.org/wiki/Frederick_Steele "Frederick Steele") _(pictured)_ occupied [Greenville, Mississippi](https://en.wikipedia.org/wiki/Greenville,_Mississippi "Greenville, Mississippi"), and operated in the surrounding area, to divert [Confederate](https://en.wikipedia.org/wiki/Confederate_States_of_America "Confederate States of America") attention from a more important movement made in [Louisiana](https://en.wikipedia.org/wiki/Louisiana "Louisiana") by Major General [John A. McClernand](https://en.wikipedia.org/wiki/John_A._McClernand "John A. McClernand")'s corps. Minor skirmishing between the two sides occurred, particularly in the early stages of the expedition. More than 1,000 slaves were freed during the operation, and large quantities of supplies and animals were destroyed or removed from the area. Along with other operations, including [Grierson's Raid](https://en.wikipedia.org/wiki/Grierson%27s_Raid "Grierson's Raid"), Steele's Greenville expedition distracted Confederate attention from McClernand's movement. Some historians have suggested that the Greenville expedition represented the Union war policy's shifting more towards expanding the war to Confederate social and economic structures and the Confederate homefront. (**[Full article...](https://en.wikipedia.org/wiki/Steele%27s_Greenville_expedition "Steele's Greenville expedition")**) + +Did you know ... +---------------- + +[![Olivia Rodrigo performing "Logical"](https://upload.wikimedia.org/wikipedia/commons/8/86/OlivaRO2150524_%2834%29_%2853727521314%29.jpg)](/wiki/File:OlivaRO2150524_\(34\)_\(53727521314\).jpg "Olivia Rodrigo performing \"Logical\"") + +Olivia Rodrigo performing "Logical" + +* ... that "**[Logical](https://en.wikipedia.org/wiki/Logical_\(song\) "Logical (song)")**" was performed on a crescent moon _(pictured)_ suspended from the ceiling on Olivia Rodrigo's [Guts World Tour](https://en.wikipedia.org/wiki/Guts_World_Tour "Guts World Tour")? +* ... that **[a newly discovered bee](https://en.wikipedia.org/wiki/Hylaeus_paumako "Hylaeus paumako")** descends from a single ancestor that reached the Hawaiian Islands between 1 million and 1.5 million years ago? +* ... that **[a single company](https://en.wikipedia.org/wiki/EviCore "EviCore")** authorizes health-insurance coverage for more than one hundred million Americans? +* ... 
that politician **[Prasenjit Barman](https://en.wikipedia.org/wiki/Prasenjit_Barman "Prasenjit Barman")** was credited for leading the restoration of the [Cooch Behar Palace](https://en.wikipedia.org/wiki/Cooch_Behar_Palace "Cooch Behar Palace")? +* ... that **[Sound Transit](https://en.wikipedia.org/wiki/Sound_Transit "Sound Transit")** has 170 pieces of permanent public art at its stations and facilities? +* ... that football player **[DJ Pickett](https://en.wikipedia.org/wiki/DJ_Pickett "DJ Pickett")** was the first [All-American](https://en.wikipedia.org/wiki/All-America "All-America") at his high school since his uncle nearly 30 years prior? +* ... that _**[The Dedalus Book of Polish Fantasy](https://en.wikipedia.org/wiki/The_Dedalus_Book_of_Polish_Fantasy "The Dedalus Book of Polish Fantasy")**_ features stories spanning two centuries of Polish literary tradition, exploring the theme of [personification of evil](https://en.wikipedia.org/wiki/Devil "Devil")? +* ... that **[Roger Tocotes](https://en.wikipedia.org/wiki/Roger_Tocotes "Roger Tocotes")** was suspected by the Duke of Clarence of masterminding the Duchess of Clarence's death, but Tocotes avoided capture until the King got involved? +* ... that Saturday Morning Strippers restored the **[Frank Lloyd Wright Home and Studio](https://en.wikipedia.org/wiki/Frank_Lloyd_Wright_Home_and_Studio "Frank Lloyd Wright Home and Studio")**? + +In the news +----------- + +[![Nightclub fire damage](https://upload.wikimedia.org/wikipedia/commons/3/3c/Remains_of_night_club_in_Kochani_after_the_fire_VOA-full.jpg)](/wiki/File:Remains_of_night_club_in_Kochani_after_the_fire_VOA-full.jpg "Nightclub fire damage") + +Nightclub fire damage + +* **[Israeli attacks](https://en.wikipedia.org/wiki/March_2025_Israeli_attacks_on_the_Gaza_Strip "March 2025 Israeli attacks on the Gaza Strip")** on the [Gaza Strip](https://en.wikipedia.org/wiki/Gaza_Strip "Gaza Strip") kill more than 400 people, ending [the Gaza war ceasefire](https://en.wikipedia.org/wiki/2025_Gaza_war_ceasefire "2025 Gaza war ceasefire"). +* **[A nightclub fire](https://en.wikipedia.org/wiki/Ko%C4%8Dani_nightclub_fire "Kočani nightclub fire")** _(damage pictured)_ in [Kočani](https://en.wikipedia.org/wiki/Ko%C4%8Dani "Kočani"), North Macedonia, kills at least 59 people and injures more than 155 others. +* In [Yemen](https://en.wikipedia.org/wiki/Yemen "Yemen"), 53 people are killed after the United States launches **[air and naval strikes](https://en.wikipedia.org/wiki/March_2025_United_States_attacks_in_Yemen "March 2025 United States attacks in Yemen")**. +* At least 42 people are killed as a result of **[storms and tornadoes](https://en.wikipedia.org/wiki/Tornado_outbreak_of_March_13%E2%80%9316,_2025 "Tornado outbreak of March 13–16, 2025")** in the [Midwestern](https://en.wikipedia.org/wiki/Midwestern_United_States "Midwestern United States") and [Southern United States](https://en.wikipedia.org/wiki/Southern_United_States "Southern United States"). +* The [People's United Party](https://en.wikipedia.org/wiki/People%27s_United_Party "People's United Party"), led by [Johnny Briceño](https://en.wikipedia.org/wiki/Johnny_Brice%C3%B1o "Johnny Briceño"), wins **[the Belizean general election](https://en.wikipedia.org/wiki/2025_Belizean_general_election "2025 Belizean general election")**. 
+ +On this day +----------- + +**[March 19](https://en.wikipedia.org/wiki/March_19 "March 19")**: **[Saint Joseph's Day](https://en.wikipedia.org/wiki/Saint_Joseph%27s_Day "Saint Joseph's Day")** (Western Christianity) + +[![Zhao Bing, Emperor of Song](https://upload.wikimedia.org/wikipedia/commons/5/57/Song_Modi.jpg)](/wiki/File:Song_Modi.jpg "Zhao Bing, Emperor of Song") + +Zhao Bing, Emperor of Song + +* [1279](https://en.wikipedia.org/wiki/1279 "1279") – [Mongol conquest of Song China](https://en.wikipedia.org/wiki/Mongol_conquest_of_the_Song_dynasty "Mongol conquest of the Song dynasty"): [Zhao Bing](https://en.wikipedia.org/wiki/Zhao_Bing "Zhao Bing") _(pictured)_, the last **[Song emperor](https://en.wikipedia.org/wiki/List_of_emperors_of_the_Song_dynasty "List of emperors of the Song dynasty")**, drowned at the end of the [Battle of Yamen](https://en.wikipedia.org/wiki/Battle_of_Yamen "Battle of Yamen"), bringing the [Song dynasty](https://en.wikipedia.org/wiki/Song_dynasty "Song dynasty") to an end after three centuries. +* [1824](https://en.wikipedia.org/wiki/1824 "1824") – American explorer **[Benjamin Morrell](https://en.wikipedia.org/wiki/Benjamin_Morrell "Benjamin Morrell")** departed Antarctica after a voyage later plagued by claims of fraud. +* [1944](https://en.wikipedia.org/wiki/1944 "1944") – The secular [oratorio](https://en.wikipedia.org/wiki/Oratorio "Oratorio") _**[A Child of Our Time](https://en.wikipedia.org/wiki/A_Child_of_Our_Time "A Child of Our Time")**_ by **[Michael Tippett](https://en.wikipedia.org/wiki/Michael_Tippett "Michael Tippett")** premiered at the [Adelphi Theatre](https://en.wikipedia.org/wiki/Adelphi_Theatre "Adelphi Theatre") in London. +* [1998](https://en.wikipedia.org/wiki/1998 "1998") – An unscheduled [Ariana Afghan Airlines](https://en.wikipedia.org/wiki/Ariana_Afghan_Airlines "Ariana Afghan Airlines") flight **[crashed into a mountain](https://en.wikipedia.org/wiki/1998_Ariana_Afghan_Airlines_Boeing_727_crash "1998 Ariana Afghan Airlines Boeing 727 crash")** on approach into Kabul, killing all 45 people aboard. +* [2011](https://en.wikipedia.org/wiki/2011 "2011") – [First Libyan Civil War](https://en.wikipedia.org/wiki/Libyan_civil_war_\(2011\) "Libyan civil war (2011)"): The [French Air Force](https://en.wikipedia.org/wiki/French_Air_and_Space_Force "French Air and Space Force") launched **[Opération Harmattan](https://en.wikipedia.org/wiki/Op%C3%A9ration_Harmattan "Opération Harmattan")**, beginning [foreign military intervention in Libya](https://en.wikipedia.org/wiki/2011_military_intervention_in_Libya "2011 military intervention in Libya"). + +* **[Lord Edmund Howard](https://en.wikipedia.org/wiki/Lord_Edmund_Howard "Lord Edmund Howard")** (d. 1539) +* **[Greville Wynne](https://en.wikipedia.org/wiki/Greville_Wynne "Greville Wynne")** (b. 1919) +* **[Joe Gaetjens](https://en.wikipedia.org/wiki/Joe_Gaetjens "Joe Gaetjens")** (b. 1924) +* **[Lise Østergaard](https://en.wikipedia.org/wiki/Lise_%C3%98stergaard "Lise Østergaard")** (d. 1996) + +Today's featured picture +------------------------ + +Other areas of Wikipedia +------------------------ + +* **[Community portal](https://en.wikipedia.org/wiki/Wikipedia:Community_portal "Wikipedia:Community portal")** – The central hub for editors, with resources, links, tasks, and announcements. +* **[Village pump](https://en.wikipedia.org/wiki/Wikipedia:Village_pump "Wikipedia:Village pump")** – Forum for discussions about Wikipedia itself, including policies and technical issues. 
+* **[Site news](https://en.wikipedia.org/wiki/Wikipedia:News "Wikipedia:News")** – Sources of news about Wikipedia and the broader Wikimedia movement. +* **[Teahouse](https://en.wikipedia.org/wiki/Wikipedia:Teahouse "Wikipedia:Teahouse")** – Ask basic questions about using or editing Wikipedia. +* **[Help desk](https://en.wikipedia.org/wiki/Wikipedia:Help_desk "Wikipedia:Help desk")** – Ask questions about using or editing Wikipedia. +* **[Reference desk](https://en.wikipedia.org/wiki/Wikipedia:Reference_desk "Wikipedia:Reference desk")** – Ask research questions about encyclopedic topics. +* **[Content portals](https://en.wikipedia.org/wiki/Wikipedia:Contents/Portals "Wikipedia:Contents/Portals")** – A unique way to navigate the encyclopedia. + +Wikipedia's sister projects +--------------------------- + +Wikipedia is written by volunteer editors and hosted by the [Wikimedia Foundation](https://en.wikipedia.org/wiki/Wikimedia_Foundation "Wikimedia Foundation"), a non-profit organization that also hosts a range of other volunteer [projects](https://wikimediafoundation.org/our-work/wikimedia-projects/ "foundationsite:our-work/wikimedia-projects/"): + +* [![Commons logo](https://upload.wikimedia.org/wikipedia/en/4/4a/Commons-logo.svg)](https://commons.wikimedia.org/wiki/ "Commons") + +* [![MediaWiki logo](https://upload.wikimedia.org/wikipedia/commons/a/a6/MediaWiki-2020-icon.svg)](https://www.mediawiki.org/wiki/ "MediaWiki") + +* [![Meta-Wiki logo](https://upload.wikimedia.org/wikipedia/commons/7/75/Wikimedia_Community_Logo.svg)](https://meta.wikimedia.org/wiki/ "Meta-Wiki") + + [Meta-Wiki](https://meta.wikimedia.org/wiki/ "m:") + Wikimedia project coordination + +* [![Wikibooks logo](https://upload.wikimedia.org/wikipedia/commons/f/fa/Wikibooks-logo.svg)](https://en.wikibooks.org/wiki/ "Wikibooks") + +* [![Wikidata logo](https://upload.wikimedia.org/wikipedia/commons/f/ff/Wikidata-logo.svg)](https://www.wikidata.org/wiki/ "Wikidata") + +* [![Wikinews logo](https://upload.wikimedia.org/wikipedia/commons/2/24/Wikinews-logo.svg)](https://en.wikinews.org/wiki/ "Wikinews") + +* [![Wikiquote logo](https://upload.wikimedia.org/wikipedia/commons/f/fa/Wikiquote-logo.svg)](https://en.wikiquote.org/wiki/ "Wikiquote") + +* [![Wikisource logo](https://upload.wikimedia.org/wikipedia/commons/4/4c/Wikisource-logo.svg)](https://en.wikisource.org/wiki/ "Wikisource") + +* [![Wikispecies logo](https://upload.wikimedia.org/wikipedia/commons/d/df/Wikispecies-logo.svg)](https://species.wikimedia.org/wiki/ "Wikispecies") + +* [![Wikiversity logo](https://upload.wikimedia.org/wikipedia/commons/0/0b/Wikiversity_logo_2017.svg)](https://en.wikiversity.org/wiki/ "Wikiversity") + +* [![Wikivoyage logo](https://upload.wikimedia.org/wikipedia/commons/d/dd/Wikivoyage-Logo-v3-icon.svg)](https://en.wikivoyage.org/wiki/ "Wikivoyage") + +* [![Wiktionary logo](https://upload.wikimedia.org/wikipedia/en/0/06/Wiktionary-logo-v2.svg)](https://en.wiktionary.org/wiki/ "Wiktionary") +[![Wiktionary logo](https://upload.wikimedia.org/wikipedia/commons/e/ec/Wiktionary-logo.svg)](https://en.wiktionary.org/wiki/ "Wiktionary") + + +Wikipedia languages +------------------- + +This Wikipedia is written in [English](https://en.wikipedia.org/wiki/English_language "English language"). Many [other Wikipedias are available](https://meta.wikimedia.org/wiki/List_of_Wikipedias "meta:List of Wikipedias"); some of the largest are listed below. 
+ +* * [العربية](https://ar.wikipedia.org/wiki/) + * [Deutsch](https://de.wikipedia.org/wiki/) + * [Español](https://es.wikipedia.org/wiki/) + * [فارسی](https://fa.wikipedia.org/wiki/)‎ + * [Français](https://fr.wikipedia.org/wiki/) + * [Italiano](https://it.wikipedia.org/wiki/) + * [Nederlands](https://nl.wikipedia.org/wiki/) + * [日本語](https://ja.wikipedia.org/wiki/) + * [Polski](https://pl.wikipedia.org/wiki/) + * [Português](https://pt.wikipedia.org/wiki/) + * [Русский](https://ru.wikipedia.org/wiki/) + * [Svenska](https://sv.wikipedia.org/wiki/) + * [Українська](https://uk.wikipedia.org/wiki/) + * [Tiếng Việt](https://vi.wikipedia.org/wiki/) + * [中文](https://zh.wikipedia.org/wiki/) + +* * [Bahasa Indonesia](https://id.wikipedia.org/wiki/) + * [Bahasa Melayu](https://ms.wikipedia.org/wiki/) + * [Bân-lâm-gú](https://zh-min-nan.wikipedia.org/wiki/) + * [Български](https://bg.wikipedia.org/wiki/) + * [Català](https://ca.wikipedia.org/wiki/) + * [Čeština](https://cs.wikipedia.org/wiki/) + * [Dansk](https://da.wikipedia.org/wiki/) + * [Esperanto](https://eo.wikipedia.org/wiki/) + * [Euskara](https://eu.wikipedia.org/wiki/) + * [עברית](https://he.wikipedia.org/wiki/) + * [Հայերեն](https://hy.wikipedia.org/wiki/) + * [한국어](https://ko.wikipedia.org/wiki/) + * [Magyar](https://hu.wikipedia.org/wiki/) + * [Norsk bokmål](https://no.wikipedia.org/wiki/) + * [Română](https://ro.wikipedia.org/wiki/) + * [Simple English](https://simple.wikipedia.org/wiki/) + * [Slovenčina](https://sk.wikipedia.org/wiki/) + * [Srpski](https://sr.wikipedia.org/wiki/) + * [Srpskohrvatski](https://sh.wikipedia.org/wiki/) + * [Suomi](https://fi.wikipedia.org/wiki/) + * [Türkçe](https://tr.wikipedia.org/wiki/) + * [Oʻzbekcha](https://uz.wikipedia.org/wiki/) + +* * [Asturianu](https://ast.wikipedia.org/wiki/) + * [Azərbaycanca](https://az.wikipedia.org/wiki/) + * [বাংলা](https://bn.wikipedia.org/wiki/) + * [Bosanski](https://bs.wikipedia.org/wiki/) + * [کوردی](https://ckb.wikipedia.org/wiki/) + * [Eesti](https://et.wikipedia.org/wiki/) + * [Ελληνικά](https://el.wikipedia.org/wiki/) + * [Frysk](https://fy.wikipedia.org/wiki/) + * [Gaeilge](https://ga.wikipedia.org/wiki/) + * [Galego](https://gl.wikipedia.org/wiki/) + * [Hrvatski](https://hr.wikipedia.org/wiki/) + * [ქართული](https://ka.wikipedia.org/wiki/) + * [Kurdî](https://ku.wikipedia.org/wiki/) + * [Latviešu](https://lv.wikipedia.org/wiki/) + * [Lietuvių](https://lt.wikipedia.org/wiki/) + * [മലയാളം](https://ml.wikipedia.org/wiki/) + * [Македонски](https://mk.wikipedia.org/wiki/) + * [မြန်မာဘာသာ](https://my.wikipedia.org/wiki/) + * [Norsk nynorsk](https://nn.wikipedia.org/wiki/) + * [ਪੰਜਾਬੀ](https://pa.wikipedia.org/wiki/) + * [Shqip](https://sq.wikipedia.org/wiki/) + * [Slovenščina](https://sl.wikipedia.org/wiki/) + * [ไทย](https://th.wikipedia.org/wiki/) + * [తెలుగు](https://te.wikipedia.org/wiki/) + * [اردو](https://ur.wikipedia.org/wiki/) diff --git a/tests/async/test_0.4.2_browser_manager.py b/tests/async/test_0_4_2_browser_manager.py similarity index 70% rename from tests/async/test_0.4.2_browser_manager.py rename to tests/async/test_0_4_2_browser_manager.py index 21b4be11b..4ee1104f3 100644 --- a/tests/async/test_0.4.2_browser_manager.py +++ b/tests/async/test_0_4_2_browser_manager.py @@ -1,24 +1,25 @@ -import os import sys -import asyncio -from crawl4ai import AsyncWebCrawler, CacheMode -from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path +import threading 
-parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) -__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) +import pytest +import requests +from crawl4ai import AsyncWebCrawler, CacheMode +from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator # Assuming that the changes made allow different configurations # for managed browser, persistent context, and so forth. +@pytest.mark.asyncio async def test_default_headless(): async with AsyncWebCrawler( headless=True, verbose=True, user_agent_mode="random", - user_agent_generator_config={"device_type": "mobile", "os_type": "android"}, + user_agent_generator_config={"platforms": ["mobile"], "os": "android"}, use_managed_browser=False, use_persistent_context=False, ignore_https_errors=True, @@ -29,20 +30,24 @@ async def test_default_headless(): cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) + assert result.success + assert result.html print("[test_default_headless] success:", result.success) print("HTML length:", len(result.html if result.html else "")) -async def test_managed_browser_persistent(): +@pytest.mark.asyncio +async def test_managed_browser_persistent(tmp_path: Path): # Treating use_persistent_context=True as managed_browser scenario. + user_data_dir: Path = tmp_path / "user_data_dir" async with AsyncWebCrawler( headless=False, verbose=True, user_agent_mode="random", - user_agent_generator_config={"device_type": "desktop", "os_type": "mac"}, + user_agent_generator_config={"platforms": ["desktop"], "os": "mac"}, use_managed_browser=True, use_persistent_context=True, # now should behave same as managed browser - user_data_dir="./outpu/test_profile", + user_data_dir=user_data_dir.as_posix(), # This should store and reuse profile data across runs ) as crawler: result = await crawler.arun( @@ -50,10 +55,11 @@ async def test_managed_browser_persistent(): cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) - print("[test_managed_browser_persistent] success:", result.success) - print("HTML length:", len(result.html if result.html else "")) + assert result.success + assert result.html +@pytest.mark.asyncio async def test_session_reuse(): # Test creating a session, using it for multiple calls session_id = "my_session" @@ -72,7 +78,7 @@ async def test_session_reuse(): session_id=session_id, markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) - print("[test_session_reuse first call] success:", result1.success) + assert result1.success # Second call: same session, possibly cookie retained result2 = await crawler.arun( @@ -81,16 +87,17 @@ async def test_session_reuse(): session_id=session_id, markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) - print("[test_session_reuse second call] success:", result2.success) + assert result2.success +@pytest.mark.asyncio async def test_magic_mode(): # Test magic mode with override_navigator and simulate_user async with AsyncWebCrawler( headless=False, verbose=True, user_agent_mode="random", - user_agent_generator_config={"device_type": "desktop", "os_type": "windows"}, + user_agent_generator_config={"platforms": ["desktop"], "os": "windows"}, use_managed_browser=False, use_persistent_context=False, magic=True, @@ -102,30 +109,55 @@ async def test_magic_mode(): cache_mode=CacheMode.BYPASS, 
markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) - print("[test_magic_mode] success:", result.success) - print("HTML length:", len(result.html if result.html else "")) + assert result.success + assert result.html + +class ProxyHandler(BaseHTTPRequestHandler): + """Simple HTTP proxy handler for testing purposes.""" + def do_GET(self): + resp = requests.get(self.path) + self.send_response(resp.status_code) + for k, v in resp.headers.items(): + self.send_header(k, v) + self.end_headers() + self.wfile.write(resp.content) -async def test_proxy_settings(): +@pytest.fixture +def proxy_server(): + """Fixture to create a simple HTTP proxy server for testing.""" + server = HTTPServer(('localhost', 0), ProxyHandler) + port = server.server_address[1] + + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + yield f"http://localhost:{port}" + + server.shutdown() + thread.join() + +@pytest.mark.asyncio +async def test_proxy_settings(proxy_server: str): # Test with a proxy (if available) to ensure code runs with proxy async with AsyncWebCrawler( headless=True, verbose=False, user_agent="Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36", - proxy="http://127.0.0.1:8080", # Assuming local proxy server for test + proxy=proxy_server, use_managed_browser=False, use_persistent_context=False, ) as crawler: result = await crawler.arun( - url="https://httpbin.org/ip", + url="http://httpbin.org/ip", cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) - print("[test_proxy_settings] success:", result.success) - if result.success: - print("HTML preview:", result.html[:200] if result.html else "") + assert result.success +@pytest.mark.asyncio async def test_ignore_https_errors(): # Test ignore HTTPS errors with a self-signed or invalid cert domain # This is just conceptual, the domain should be one that triggers SSL error. 
@@ -143,18 +175,10 @@ async def test_ignore_https_errors(): cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator(options={"ignore_links": True}), ) - print("[test_ignore_https_errors] success:", result.success) - - -async def main(): - print("Running tests...") - # await test_default_headless() - # await test_managed_browser_persistent() - # await test_session_reuse() - # await test_magic_mode() - # await test_proxy_settings() - await test_ignore_https_errors() + assert result.success if __name__ == "__main__": - asyncio.run(main()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_0.4.2_config_params.py b/tests/async/test_0_4_2_config_params.py similarity index 82% rename from tests/async/test_0.4.2_config_params.py rename to tests/async/test_0_4_2_config_params.py index 9a15f864d..1bc6d2dc0 100644 --- a/tests/async/test_0.4.2_config_params.py +++ b/tests/async/test_0_4_2_config_params.py @@ -1,18 +1,18 @@ -import os, sys +import os +import sys -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) -__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) +from httpx import codes +import pytest -import asyncio -from crawl4ai import AsyncWebCrawler, CacheMode +from crawl4ai import AsyncWebCrawler, CacheMode, DefaultMarkdownGenerator from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig +from crawl4ai.chunking_strategy import RegexChunking from crawl4ai.content_filter_strategy import PruningContentFilter from crawl4ai.extraction_strategy import JsonCssExtractionStrategy -from crawl4ai.chunking_strategy import RegexChunking # Category 1: Browser Configuration Tests +@pytest.mark.asyncio async def test_browser_config_object(): """Test the new BrowserConfig object with various browser settings""" browser_config = BrowserConfig( @@ -22,7 +22,7 @@ async def test_browser_config_object(): viewport_height=1080, use_managed_browser=True, user_agent_mode="random", - user_agent_generator_config={"device_type": "desktop", "os_type": "windows"}, + user_agent_generator_config={"os": "windows"}, ) async with AsyncWebCrawler(config=browser_config, verbose=True) as crawler: @@ -31,6 +31,7 @@ async def test_browser_config_object(): assert len(result.html) > 0, "No HTML content retrieved" +@pytest.mark.asyncio async def test_browser_performance_config(): """Test browser configurations focused on performance""" browser_config = BrowserConfig( @@ -44,10 +45,11 @@ async def test_browser_performance_config(): async with AsyncWebCrawler(config=browser_config) as crawler: result = await crawler.arun("https://example.com") assert result.success, "Performance optimized crawl failed" - assert result.status_code == 200, "Unexpected status code" + assert result.status_code == codes.OK, "Unexpected status code" # Category 2: Content Processing Tests +@pytest.mark.asyncio async def test_content_extraction_config(): """Test content extraction with various strategies""" crawler_config = CrawlerRunConfig( @@ -60,7 +62,9 @@ async def test_content_extraction_config(): } ), chunking_strategy=RegexChunking(), - content_filter=PruningContentFilter(), + markdown_generator=DefaultMarkdownGenerator( + content_filter=PruningContentFilter(), + ), ) async with AsyncWebCrawler() as crawler: @@ -72,6 +76,7 @@ async def test_content_extraction_config(): # Category 3: Cache and Session Management Tests +@pytest.mark.asyncio async def 
test_cache_and_session_management(): """Test different cache modes and session handling""" browser_config = BrowserConfig(use_persistent_context=True) @@ -93,9 +98,10 @@ async def test_cache_and_session_management(): # Category 4: Media Handling Tests +@pytest.mark.asyncio async def test_media_handling_config(): """Test configurations related to media handling""" - # Get the base path for home directroy ~/.crawl4ai/downloads, make sure it exists + # Get the base path for home directory ~/.crawl4ai/downloads, make sure it exists os.makedirs(os.path.expanduser("~/.crawl4ai/downloads"), exist_ok=True) browser_config = BrowserConfig( viewport_width=1920, @@ -118,6 +124,7 @@ async def test_media_handling_config(): # Category 5: Anti-Bot and Site Interaction Tests +@pytest.mark.asyncio async def test_antibot_config(): """Test configurations for handling anti-bot measures""" crawler_config = CrawlerRunConfig( @@ -136,6 +143,7 @@ async def test_antibot_config(): # Category 6: Parallel Processing Tests +@pytest.mark.asyncio async def test_parallel_processing(): """Test parallel processing capabilities""" crawler_config = CrawlerRunConfig(mean_delay=0.5, max_range=1.0, semaphore_count=5) @@ -149,6 +157,7 @@ async def test_parallel_processing(): # Category 7: Backwards Compatibility Tests +@pytest.mark.asyncio async def test_legacy_parameter_support(): """Test that legacy parameters still work""" async with AsyncWebCrawler( @@ -158,13 +167,14 @@ async def test_legacy_parameter_support(): "https://example.com", screenshot=True, word_count_threshold=200, - bypass_cache=True, + cache_mode=CacheMode.BYPASS, css_selector=".main-content", ) assert result.success, "Legacy parameter support failed" # Category 8: Mixed Configuration Tests +@pytest.mark.asyncio async def test_mixed_config_usage(): """Test mixing new config objects with legacy parameters""" browser_config = BrowserConfig(headless=True) @@ -184,28 +194,6 @@ async def test_mixed_config_usage(): if __name__ == "__main__": + import subprocess - async def run_tests(): - test_functions = [ - test_browser_config_object, - # test_browser_performance_config, - # test_content_extraction_config, - # test_cache_and_session_management, - # test_media_handling_config, - # test_antibot_config, - # test_parallel_processing, - # test_legacy_parameter_support, - # test_mixed_config_usage - ] - - for test in test_functions: - print(f"\nRunning {test.__name__}...") - try: - await test() - print(f"✓ {test.__name__} passed") - except AssertionError as e: - print(f"✗ {test.__name__} failed: {str(e)}") - except Exception as e: - print(f"✗ {test.__name__} error: {str(e)}") - - asyncio.run(run_tests()) + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_async_doanloader.py b/tests/async/test_async_doanloader.py deleted file mode 100644 index 055886cbf..000000000 --- a/tests/async/test_async_doanloader.py +++ /dev/null @@ -1,247 +0,0 @@ -import os -import sys -import asyncio -import shutil -from typing import List -import tempfile - -# Add the parent directory to the Python path -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) - -from crawl4ai.async_webcrawler import AsyncWebCrawler - - -class TestDownloads: - def __init__(self): - self.temp_dir = tempfile.mkdtemp(prefix="crawl4ai_test_") - self.download_dir = os.path.join(self.temp_dir, "downloads") - os.makedirs(self.download_dir, exist_ok=True) - self.results: List[str] = [] - - def cleanup(self): - 
shutil.rmtree(self.temp_dir) - - def log_result(self, test_name: str, success: bool, message: str = ""): - result = f"{'✅' if success else '❌'} {test_name}: {message}" - self.results.append(result) - print(result) - - async def test_basic_download(self): - """Test basic file download functionality""" - try: - async with AsyncWebCrawler( - accept_downloads=True, downloads_path=self.download_dir, verbose=True - ) as crawler: - # Python.org downloads page typically has stable download links - result = await crawler.arun( - url="https://www.python.org/downloads/", - js_code=""" - // Click first download link - const downloadLink = document.querySelector('a[href$=".exe"]'); - if (downloadLink) downloadLink.click(); - """, - ) - - success = ( - result.downloaded_files is not None - and len(result.downloaded_files) > 0 - ) - self.log_result( - "Basic Download", - success, - f"Downloaded {len(result.downloaded_files or [])} files" - if success - else "No files downloaded", - ) - except Exception as e: - self.log_result("Basic Download", False, str(e)) - - async def test_persistent_context_download(self): - """Test downloads with persistent context""" - try: - user_data_dir = os.path.join(self.temp_dir, "user_data") - os.makedirs(user_data_dir, exist_ok=True) - - async with AsyncWebCrawler( - accept_downloads=True, - downloads_path=self.download_dir, - use_persistent_context=True, - user_data_dir=user_data_dir, - verbose=True, - ) as crawler: - result = await crawler.arun( - url="https://www.python.org/downloads/", - js_code=""" - const downloadLink = document.querySelector('a[href$=".exe"]'); - if (downloadLink) downloadLink.click(); - """, - ) - - success = ( - result.downloaded_files is not None - and len(result.downloaded_files) > 0 - ) - self.log_result( - "Persistent Context Download", - success, - f"Downloaded {len(result.downloaded_files or [])} files" - if success - else "No files downloaded", - ) - except Exception as e: - self.log_result("Persistent Context Download", False, str(e)) - - async def test_multiple_downloads(self): - """Test multiple simultaneous downloads""" - try: - async with AsyncWebCrawler( - accept_downloads=True, downloads_path=self.download_dir, verbose=True - ) as crawler: - result = await crawler.arun( - url="https://www.python.org/downloads/", - js_code=""" - // Click multiple download links - const downloadLinks = document.querySelectorAll('a[href$=".exe"]'); - downloadLinks.forEach(link => link.click()); - """, - ) - - success = ( - result.downloaded_files is not None - and len(result.downloaded_files) > 1 - ) - self.log_result( - "Multiple Downloads", - success, - f"Downloaded {len(result.downloaded_files or [])} files" - if success - else "Not enough files downloaded", - ) - except Exception as e: - self.log_result("Multiple Downloads", False, str(e)) - - async def test_different_browsers(self): - """Test downloads across different browser types""" - browsers = ["chromium", "firefox", "webkit"] - - for browser_type in browsers: - try: - async with AsyncWebCrawler( - accept_downloads=True, - downloads_path=self.download_dir, - browser_type=browser_type, - verbose=True, - ) as crawler: - result = await crawler.arun( - url="https://www.python.org/downloads/", - js_code=""" - const downloadLink = document.querySelector('a[href$=".exe"]'); - if (downloadLink) downloadLink.click(); - """, - ) - - success = ( - result.downloaded_files is not None - and len(result.downloaded_files) > 0 - ) - self.log_result( - f"{browser_type.title()} Download", - success, - 
f"Downloaded {len(result.downloaded_files or [])} files" - if success - else "No files downloaded", - ) - except Exception as e: - self.log_result(f"{browser_type.title()} Download", False, str(e)) - - async def test_edge_cases(self): - """Test various edge cases""" - - # Test 1: Downloads without specifying download path - try: - async with AsyncWebCrawler(accept_downloads=True, verbose=True) as crawler: - result = await crawler.arun( - url="https://www.python.org/downloads/", - js_code="document.querySelector('a[href$=\".exe\"]').click()", - ) - self.log_result( - "Default Download Path", - True, - f"Downloaded to default path: {result.downloaded_files[0] if result.downloaded_files else 'None'}", - ) - except Exception as e: - self.log_result("Default Download Path", False, str(e)) - - # Test 2: Downloads with invalid path - try: - async with AsyncWebCrawler( - accept_downloads=True, - downloads_path="/invalid/path/that/doesnt/exist", - verbose=True, - ) as crawler: - result = await crawler.arun( - url="https://www.python.org/downloads/", - js_code="document.querySelector('a[href$=\".exe\"]').click()", - ) - self.log_result( - "Invalid Download Path", False, "Should have raised an error" - ) - except Exception: - self.log_result( - "Invalid Download Path", True, "Correctly handled invalid path" - ) - - # Test 3: Download with accept_downloads=False - try: - async with AsyncWebCrawler(accept_downloads=False, verbose=True) as crawler: - result = await crawler.arun( - url="https://www.python.org/downloads/", - js_code="document.querySelector('a[href$=\".exe\"]').click()", - ) - success = result.downloaded_files is None - self.log_result( - "Disabled Downloads", - success, - "Correctly ignored downloads" - if success - else "Unexpectedly downloaded files", - ) - except Exception as e: - self.log_result("Disabled Downloads", False, str(e)) - - async def run_all_tests(self): - """Run all test cases""" - print("\n🧪 Running Download Tests...\n") - - test_methods = [ - self.test_basic_download, - self.test_persistent_context_download, - self.test_multiple_downloads, - self.test_different_browsers, - self.test_edge_cases, - ] - - for test in test_methods: - print(f"\n📝 Running {test.__doc__}...") - await test() - await asyncio.sleep(2) # Brief pause between tests - - print("\n📊 Test Results Summary:") - for result in self.results: - print(result) - - successes = len([r for r in self.results if "✅" in r]) - total = len(self.results) - print(f"\nTotal: {successes}/{total} tests passed") - - self.cleanup() - - -async def main(): - tester = TestDownloads() - await tester.run_all_tests() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/tests/async/test_async_downloader.py b/tests/async/test_async_downloader.py new file mode 100644 index 000000000..549de676e --- /dev/null +++ b/tests/async/test_async_downloader.py @@ -0,0 +1,178 @@ +import os +import sys +from pathlib import Path + +import pytest +from playwright.async_api import Browser, BrowserType, async_playwright + +from crawl4ai.async_configs import BrowserConfig +from crawl4ai.async_webcrawler import AsyncWebCrawler +from crawl4ai.models import CrawlResultContainer + + +@pytest.fixture(scope="session") +def downloads_path(tmp_path_factory: pytest.TempPathFactory) -> Path: + return tmp_path_factory.mktemp("downloads") + + +def assert_downloaded(result: CrawlResultContainer): + assert result.downloaded_files, "No files downloaded" + missing: list[str] = [] + + # Best effort to clean up downloaded files + for file in 
result.downloaded_files: + if not os.path.exists(file): + missing.append(file) + continue + + os.remove(file) + assert not missing, f"Files not downloaded: {missing}" + + +class TestDownloads: + @pytest.mark.asyncio + async def test_basic_download(self, downloads_path: Path): + """Test basic file download functionality.""" + async with AsyncWebCrawler( + accept_downloads=True, + downloads_path=downloads_path.as_posix(), + verbose=True, + ) as crawler: + # Python.org downloads page typically has stable download links + result: CrawlResultContainer = await crawler.arun( + url="https://www.python.org/downloads/", + js_code=""" + // Click first download link + const downloadLink = document.querySelector('a[href$=".exe"]'); + if (downloadLink) downloadLink.click(); + """, + ) + assert_downloaded(result) + + @pytest.mark.asyncio + async def test_persistent_context_download( + self, tmp_path_factory: pytest.TempPathFactory, downloads_path: Path + ): + """Test downloads with persistent context.""" + user_data_dir: Path = tmp_path_factory.mktemp("user_data") + + async with AsyncWebCrawler( + accept_downloads=True, + downloads_path=downloads_path, + use_persistent_context=True, + user_data_dir=user_data_dir, + verbose=True, + ) as crawler: + result: CrawlResultContainer = await crawler.arun( + url="https://www.python.org/downloads/", + js_code=""" + const downloadLink = document.querySelector('a[href$=".exe"]'); + if (downloadLink) downloadLink.click(); + """, + ) + assert_downloaded(result) + + @pytest.mark.asyncio + async def test_multiple_downloads(self, downloads_path: Path): + """Test multiple simultaneous downloads.""" + async with AsyncWebCrawler( + accept_downloads=True, downloads_path=downloads_path, verbose=True + ) as crawler: + result: CrawlResultContainer = await crawler.arun( + url="https://www.python.org/downloads/", + js_code=""" + // Click multiple download links + const downloadLinks = document.querySelectorAll('a[href$=".exe"]'); + downloadLinks.forEach(link => link.click()); + """, + ) + assert_downloaded(result) + + @pytest.mark.asyncio + @pytest.mark.parametrize("browser_type", ["chromium", "firefox", "webkit"]) + async def test_different_browsers(self, browser_type: str, downloads_path: Path): + """Test downloads across different browser types.""" + try: + # Check if the browser is installed and skip if not. 
+ async with async_playwright() as p: + browsers: dict[str, BrowserType] = { + "chromium": p.chromium, + "firefox": p.firefox, + "webkit": p.webkit, + } + if browser_type not in browsers: + raise ValueError(f"Invalid browser type: {browser_type}") + bt: BrowserType = browsers[browser_type] + browser: Browser = await bt.launch(headless=True) + await browser.close() + except Exception as e: + if "Executable doesn't exist at" in str(e): + pytest.skip(f"{browser_type} is not installed: {e}") + return + raise + + async with AsyncWebCrawler( + config=BrowserConfig( + accept_downloads=True, + downloads_path=downloads_path, + browser_type=browser_type, + verbose=True, + ), + ) as crawler: + result: CrawlResultContainer = await crawler.arun( + url="https://www.python.org/downloads/", + js_code=""" + const downloadLink = document.querySelector('a[href$=".exe"]'); + if (downloadLink) downloadLink.click(); + """, + ) + assert_downloaded(result) + + @pytest.mark.asyncio + async def test_without_download_path(self): + async with AsyncWebCrawler(accept_downloads=True, verbose=True) as crawler: + result: CrawlResultContainer = await crawler.arun( + url="https://www.python.org/downloads/", + js_code=""" + const downloadLink = document.querySelector('a[href$=".exe"]'); + if (downloadLink) downloadLink.click(); + """, + ) + assert_downloaded(result) + for file in result.downloaded_files: # pyright: ignore[reportOptionalIterable] + if os.path.exists(file): + os.remove(file) + + @pytest.mark.asyncio + async def test_invalid_path(self): + with pytest.raises(ValueError): + async with AsyncWebCrawler( + accept_downloads=True, + downloads_path="/invalid\0/path/that/doesnt/exist", + verbose=True, + ) as crawler: + await crawler.arun( + url="https://www.python.org/downloads/", + js_code=""" + const downloadLink = document.querySelector('a[href$=".exe"]'); + if (downloadLink) downloadLink.click(); + """, + ) + + @pytest.mark.asyncio + async def test_accept_downloads_false(self): + async with AsyncWebCrawler(accept_downloads=False, verbose=True) as crawler: + result: CrawlResultContainer = await crawler.arun( + url="https://www.python.org/downloads/", + js_code=""" + const downloadLink = document.querySelector('a[href$=".exe"]'); + if (downloadLink) downloadLink.click(); + """, + ) + assert not result.downloaded_files, "Unexpectedly downloaded files" + + +if __name__ == "__main__": + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_basic_crawling.py b/tests/async/test_basic_crawling.py index ee4bb6339..76a4d5c12 100644 --- a/tests/async/test_basic_crawling.py +++ b/tests/async/test_basic_crawling.py @@ -1,14 +1,9 @@ -import os import sys -import pytest import time -# Add the parent directory to the Python path -parent_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -) -sys.path.append(parent_dir) +import pytest +from crawl4ai import CacheMode from crawl4ai.async_webcrawler import AsyncWebCrawler @@ -16,7 +11,7 @@ async def test_successful_crawl(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - result = await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result.success assert result.url == url assert result.html @@ -28,7 +23,7 @@ async def test_successful_crawl(): async def test_invalid_url(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.invalidurl12345.com" - result = 
await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert not result.success assert result.error_message @@ -41,7 +36,7 @@ async def test_multiple_urls(): "https://www.example.com", "https://www.python.org", ] - results = await crawler.arun_many(urls=urls, bypass_cache=True) + results = await crawler.arun_many(urls=urls, cache_mode=CacheMode.BYPASS) assert len(results) == len(urls) assert all(result.success for result in results) assert all(result.html for result in results) @@ -52,7 +47,7 @@ async def test_javascript_execution(): async with AsyncWebCrawler(verbose=True) as crawler: js_code = "document.body.innerHTML = '
<h1>Modified by JS</h1>';" url = "https://www.example.com" - result = await crawler.arun(url=url, bypass_cache=True, js_code=js_code) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, js_code=js_code) assert result.success assert "<h1>Modified by JS</h1>
" in result.html @@ -69,7 +64,7 @@ async def test_concurrent_crawling_performance(): ] start_time = time.time() - results = await crawler.arun_many(urls=urls, bypass_cache=True) + results = await crawler.arun_many(urls=urls, cache_mode=CacheMode.BYPASS) end_time = time.time() total_time = end_time - start_time @@ -87,4 +82,6 @@ async def test_concurrent_crawling_performance(): # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_caching.py b/tests/async/test_caching.py index d7f6efb54..c6d919538 100644 --- a/tests/async/test_caching.py +++ b/tests/async/test_caching.py @@ -1,12 +1,9 @@ -import os -import sys -import pytest import asyncio +import sys -# Add the parent directory to the Python path -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) +import pytest +from crawl4ai import CacheMode from crawl4ai.async_webcrawler import AsyncWebCrawler @@ -17,7 +14,7 @@ async def test_caching(): # First crawl (should not use cache) start_time = asyncio.get_event_loop().time() - result1 = await crawler.arun(url=url, bypass_cache=True) + result1 = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) end_time = asyncio.get_event_loop().time() time_taken1 = end_time - start_time @@ -42,8 +39,8 @@ async def test_bypass_cache(): result1 = await crawler.arun(url=url, bypass_cache=False) assert result1.success - # Second crawl with bypass_cache=True - result2 = await crawler.arun(url=url, bypass_cache=True) + # Second crawl with cache_mode=CacheMode.BYPASS + result2 = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result2.success # Content should be different (or at least, not guaranteed to be the same) @@ -84,4 +81,6 @@ async def test_flush_cache(): # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_chunking_and_extraction_strategies.py b/tests/async/test_chunking_and_extraction_strategies.py index 90e17a9d6..b8fdd28ee 100644 --- a/tests/async/test_chunking_and_extraction_strategies.py +++ b/tests/async/test_chunking_and_extraction_strategies.py @@ -1,16 +1,14 @@ +import json import os import sys -import pytest -import json -# Add the parent directory to the Python path -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) +import pytest -from crawl4ai import LLMConfig +from crawl4ai import CacheMode +from crawl4ai.async_configs import LLMConfig from crawl4ai.async_webcrawler import AsyncWebCrawler from crawl4ai.chunking_strategy import RegexChunking -from crawl4ai.extraction_strategy import LLMExtractionStrategy +from crawl4ai.extraction_strategy import CosineStrategy, LLMExtractionStrategy, NoExtractionStrategy @pytest.mark.asyncio @@ -19,42 +17,44 @@ async def test_regex_chunking(): url = "https://www.nbcnews.com/business" chunking_strategy = RegexChunking(patterns=["\n\n"]) result = await crawler.arun( - url=url, chunking_strategy=chunking_strategy, bypass_cache=True + url=url, + chunking_strategy=chunking_strategy, + extraction_strategy=NoExtractionStrategy(), + cache_mode=CacheMode.BYPASS, ) assert result.success assert result.extracted_content chunks = json.loads(result.extracted_content) assert len(chunks) > 1 # Ensure multiple chunks were created - -# @pytest.mark.asyncio 
-# async def test_cosine_strategy(): -# async with AsyncWebCrawler(verbose=True) as crawler: -# url = "https://www.nbcnews.com/business" -# extraction_strategy = CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold=0.3) -# result = await crawler.arun( -# url=url, -# extraction_strategy=extraction_strategy, -# bypass_cache=True -# ) -# assert result.success -# assert result.extracted_content -# extracted_data = json.loads(result.extracted_content) -# assert len(extracted_data) > 0 -# assert all('tags' in item for item in extracted_data) +@pytest.mark.asyncio +async def test_cosine_strategy(): + async with AsyncWebCrawler(verbose=True) as crawler: + url = "https://www.nbcnews.com/business" + extraction_strategy = CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold=0.3) + result = await crawler.arun( + url=url, + extraction_strategy=extraction_strategy, + cache_mode=CacheMode.BYPASS + ) + assert result.success + assert result.extracted_content + extracted_data = json.loads(result.extracted_content) + assert len(extracted_data) > 0 + assert all('tags' in item for item in extracted_data) @pytest.mark.asyncio async def test_llm_extraction_strategy(): + if not os.getenv("OPENAI_API_KEY"): + pytest.skip("Skipping env OPENAI_API_KEY not set") async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" extraction_strategy = LLMExtractionStrategy( - llm_config=LLMConfig(provider="openai/gpt-4o-mini",api_token=os.getenv("OPENAI_API_KEY")), + llm_config=LLMConfig(provider="openai/gpt-4o-mini", api_token=os.getenv("OPENAI_API_KEY")), instruction="Extract only content related to technology", ) - result = await crawler.arun( - url=url, extraction_strategy=extraction_strategy, bypass_cache=True - ) + result = await crawler.arun(url=url, extraction_strategy=extraction_strategy, cache_mode=CacheMode.BYPASS) assert result.success assert result.extracted_content extracted_data = json.loads(result.extracted_content) @@ -62,25 +62,30 @@ async def test_llm_extraction_strategy(): assert all("content" in item for item in extracted_data) -# @pytest.mark.asyncio -# async def test_combined_chunking_and_extraction(): -# async with AsyncWebCrawler(verbose=True) as crawler: -# url = "https://www.nbcnews.com/business" -# chunking_strategy = RegexChunking(patterns=["\n\n"]) -# extraction_strategy = CosineStrategy(word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold=0.3) -# result = await crawler.arun( -# url=url, -# chunking_strategy=chunking_strategy, -# extraction_strategy=extraction_strategy, -# bypass_cache=True -# ) -# assert result.success -# assert result.extracted_content -# extracted_data = json.loads(result.extracted_content) -# assert len(extracted_data) > 0 -# assert all('tags' in item for item in extracted_data) -# assert all('content' in item for item in extracted_data) +@pytest.mark.asyncio +async def test_combined_chunking_and_extraction(): + async with AsyncWebCrawler(verbose=True) as crawler: + url = "https://www.nbcnews.com/business" + chunking_strategy = RegexChunking(patterns=["\n\n"]) + extraction_strategy = CosineStrategy( + word_count_threshold=10, max_dist=0.2, linkage_method="ward", top_k=3, sim_threshold=0.3 + ) + result = await crawler.arun( + url=url, + chunking_strategy=chunking_strategy, + extraction_strategy=extraction_strategy, + cache_mode=CacheMode.BYPASS, + ) + assert result.success + assert result.extracted_content + 
extracted_data = json.loads(result.extracted_content) + assert len(extracted_data) > 0 + assert all('tags' in item for item in extracted_data) + assert all('content' in item for item in extracted_data) + # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_content_extraction.py b/tests/async/test_content_extraction.py index 9372387ad..5237feece 100644 --- a/tests/async/test_content_extraction.py +++ b/tests/async/test_content_extraction.py @@ -1,11 +1,8 @@ -import os import sys -import pytest -# Add the parent directory to the Python path -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) +import pytest +from crawl4ai import CacheMode from crawl4ai.async_webcrawler import AsyncWebCrawler @@ -13,7 +10,7 @@ async def test_extract_markdown(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - result = await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result.success assert result.markdown assert isinstance(result.markdown, str) @@ -24,7 +21,7 @@ async def test_extract_markdown(): async def test_extract_cleaned_html(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - result = await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result.success assert result.cleaned_html assert isinstance(result.cleaned_html, str) @@ -35,7 +32,7 @@ async def test_extract_cleaned_html(): async def test_extract_media(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - result = await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result.success assert result.media media = result.media @@ -52,7 +49,7 @@ async def test_extract_media(): async def test_extract_links(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - result = await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result.success assert result.links links = result.links @@ -70,7 +67,7 @@ async def test_extract_links(): async def test_extract_metadata(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.nbcnews.com/business" - result = await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result.success assert result.metadata metadata = result.metadata @@ -85,7 +82,7 @@ async def test_css_selector_extraction(): url = "https://www.nbcnews.com/business" css_selector = "h1, h2, h3" result = await crawler.arun( - url=url, bypass_cache=True, css_selector=css_selector + url=url, cache_mode=CacheMode.BYPASS, css_selector=css_selector ) assert result.success assert result.markdown @@ -94,4 +91,6 @@ async def test_css_selector_extraction(): # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_content_filter_bm25.py b/tests/async/test_content_filter_bm25.py index f05a8af7b..a0dff98ec 100644 --- a/tests/async/test_content_filter_bm25.py +++ 
b/tests/async/test_content_filter_bm25.py @@ -1,10 +1,9 @@ -import os, sys +import sys +from typing import Union + import pytest from bs4 import BeautifulSoup - -# Add the parent directory to the Python path -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) +from bs4.element import PageElement, Tag from crawl4ai.content_filter_strategy import BM25ContentFilter @@ -16,7 +15,7 @@ def basic_html(): Test Article - +
<h1>Main Heading</h1>
@@ -38,7 +37,7 @@ def wiki_html():
<h1>Article Title</h1>
- <h2>Section 1</h2>
+ <h2>Section One</h2>
Short but important section header description.
Long paragraph with sufficient words to meet the minimum threshold. This paragraph continues with more text to ensure we have enough content for proper testing. We need to make sure this has enough words to pass our filters and be considered valid content for extraction purposes.
@@ -78,7 +77,9 @@ def test_user_query_override(self, basic_html): # Access internal state to verify query usage soup = BeautifulSoup(basic_html, "lxml") - extracted_query = filter.extract_page_query(soup.find("head")) + head: Union[PageElement, Tag, None] = soup.find("head") + assert isinstance(head, Tag) + extracted_query = filter.extract_page_query(soup, head) assert extracted_query == user_query assert "Test description" not in extracted_query
@@ -89,7 +90,7 @@ def test_header_extraction(self, wiki_html): contents = filter.filter_content(wiki_html) combined_content = " ".join(contents).lower() - assert "section 1" in combined_content, "Should include section header" + assert "section one" in combined_content, "Should include section header" assert "article title" in combined_content, "Should include main title" def test_no_metadata_fallback(self, no_meta_html):
@@ -106,7 +107,6 @@ def test_empty_input(self): """Test handling of empty input""" filter = BM25ContentFilter() assert filter.filter_content("") == [] - assert filter.filter_content(None) == [] def test_malformed_html(self): """Test handling of malformed HTML"""
@@ -142,7 +142,7 @@ def test_large_content(self): """Test handling of large content blocks""" large_html = f""" - {'<p>Test content. ' * 1000} + {'Test content. ' * 1000}
""" filter = BM25ContentFilter() @@ -180,4 +180,6 @@ def test_performance(self, basic_html): if __name__ == "__main__": - pytest.main([__file__]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_content_filter_prune.py b/tests/async/test_content_filter_prune.py index 1f75a9e1e..d443a5cd0 100644 --- a/tests/async/test_content_filter_prune.py +++ b/tests/async/test_content_filter_prune.py @@ -1,8 +1,6 @@ -import os, sys -import pytest +import sys -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) +import pytest from crawl4ai.content_filter_strategy import PruningContentFilter @@ -86,11 +84,11 @@ def test_min_word_threshold(self, mixed_content_html): def test_threshold_types(self, basic_html): """Test fixed vs dynamic thresholds""" - fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=0.48) - dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=0.45) + fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=1.1) + dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=1.1) - fixed_contents = fixed_filter.filter_content(basic_html) - dynamic_contents = dynamic_filter.filter_content(basic_html) + fixed_contents = "".join(fixed_filter.filter_content(basic_html)) + dynamic_contents = "".join(dynamic_filter.filter_content(basic_html)) assert len(fixed_contents) != len( dynamic_contents @@ -120,7 +118,6 @@ def test_empty_input(self): """Test handling of empty input""" filter = PruningContentFilter() assert filter.filter_content("") == [] - assert filter.filter_content(None) == [] def test_malformed_html(self): """Test handling of malformed HTML""" @@ -167,4 +164,6 @@ def test_consistent_output(self, basic_html): if __name__ == "__main__": - pytest.main([__file__]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_content_scraper_strategy.py b/tests/async/test_content_scraper_strategy.py index e6caf2405..7db09ae9a 100644 --- a/tests/async/test_content_scraper_strategy.py +++ b/tests/async/test_content_scraper_strategy.py @@ -1,219 +1,232 @@ -import os +import csv import sys import time -import csv +from dataclasses import asdict, dataclass +from functools import lru_cache +from pathlib import Path +from typing import Any, List + +import pytest +from _pytest.mark import ParameterSet from tabulate import tabulate -from dataclasses import dataclass -from typing import List -parent_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +from crawl4ai.content_scraping_strategy import ( + ContentScrapingStrategy, + WebScrapingStrategy, ) -sys.path.append(parent_dir) -__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) - -from crawl4ai.content_scraping_strategy import WebScrapingStrategy from crawl4ai.content_scraping_strategy import ( WebScrapingStrategy as WebScrapingStrategyCurrent, ) -# from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent +from crawl4ai.models import ScrapingResult @dataclass -class TestResult: +class Result: name: str + strategy: str success: bool images: int internal_links: int external_links: int - markdown_length: int + cleaned_html_length: int execution_time: float -class StrategyTester: - def __init__(self): - self.new_scraper = WebScrapingStrategy() - self.current_scraper = WebScrapingStrategyCurrent() - 
with open(__location__ + "/sample_wikipedia.html", "r", encoding="utf-8") as f: - self.WIKI_HTML = f.read() - self.results = {"new": [], "current": []} - - def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]: - results = [] - for scraper in [self.new_scraper, self.current_scraper]: - start_time = time.time() - result = scraper._get_content_of_website_optimized( - url="https://en.wikipedia.org/wiki/Test", html=self.WIKI_HTML, **kwargs - ) - execution_time = time.time() - start_time - - test_result = TestResult( - name=name, - success=result["success"], - images=len(result["media"]["images"]), - internal_links=len(result["links"]["internal"]), - external_links=len(result["links"]["external"]), - markdown_length=len(result["markdown"]), - execution_time=execution_time, - ) - results.append(test_result) - - return results[0], results[1] # new, current - - def run_all_tests(self): - test_cases = [ - ("Basic Extraction", {}), - ("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}), - ("Word Threshold", {"word_count_threshold": 50}), - ("CSS Selector", {"css_selector": "div.mw-parser-output > p"}), - ( - "Link Exclusions", - { - "exclude_external_links": True, - "exclude_social_media_links": True, - "exclude_domains": ["facebook.com", "twitter.com"], - }, - ), - ( - "Media Handling", - { - "exclude_external_images": True, - "image_description_min_word_threshold": 20, - }, - ), - ("Text Only", {"only_text": True, "remove_forms": True}), - ("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}), - ( - "HTML2Text Options", - { - "html2text": { - "skip_internal_links": True, - "single_line_break": True, - "mark_code": True, - "preserve_tags": ["pre", "code"], - } - }, - ), +@lru_cache +@pytest.fixture +def wiki_html() -> str: + file_path: Path = Path(__file__).parent / "sample_wikipedia.html" + with file_path.open("r", encoding="utf-8") as f: + return f.read() + + +results: List[Result] = [] + + +def print_comparison_table(): + """Print comparison table of results.""" + if not results: + return + + table_data = [] + headers = [ + "Test Name", + "Strategy", + "Success", + "Images", + "Internal Links", + "External Links", + "Cleaned HTML Length", + "Time (s)", + ] + + all_results: List[tuple[str, Result, Result]] = [] + new_results = [result for result in results if result.strategy == "new"] + current_results = [result for result in results if result.strategy == "current"] + for new_result in new_results: + for current_result in current_results: + if new_result.name == current_result.name: + all_results.append((new_result.name, new_result, current_result)) + + for name, new_result, current_result in all_results: + # Check for differences + differences = [] + if new_result.images != current_result.images: + differences.append("images") + if new_result.internal_links != current_result.internal_links: + differences.append("internal_links") + if new_result.external_links != current_result.external_links: + differences.append("external_links") + if new_result.cleaned_html_length != current_result.cleaned_html_length: + differences.append("cleaned_html") + + # Add row for new strategy + new_row = [ + name, + "New", + new_result.success, + new_result.images, + new_result.internal_links, + new_result.external_links, + new_result.cleaned_html_length, + f"{new_result.execution_time:.3f}", + ] + table_data.append(new_row) + + # Add row for current strategy + current_row = [ + "", + "Current", + current_result.success, + current_result.images, + 
current_result.internal_links, + current_result.external_links, + current_result.cleaned_html_length, + f"{current_result.execution_time:.3f}", ] + table_data.append(current_row) - all_results = [] - for name, kwargs in test_cases: - try: - new_result, current_result = self.run_test(name, **kwargs) - all_results.append((name, new_result, current_result)) - except Exception as e: - print(f"Error in {name}: {str(e)}") - - self.save_results_to_csv(all_results) - self.print_comparison_table(all_results) - - def save_results_to_csv(self, all_results: List[tuple]): - csv_file = os.path.join(__location__, "strategy_comparison_results.csv") - with open(csv_file, "w", newline="") as f: - writer = csv.writer(f) - writer.writerow( - [ - "Test Name", - "Strategy", - "Success", - "Images", - "Internal Links", - "External Links", - "Markdown Length", - "Execution Time", - ] + # Add difference summary if any + if differences: + table_data.append( + ["", "⚠️ Differences", ", ".join(differences), "", "", "", "", ""] ) - for name, new_result, current_result in all_results: - writer.writerow( - [ - name, - "New", - new_result.success, - new_result.images, - new_result.internal_links, - new_result.external_links, - new_result.markdown_length, - f"{new_result.execution_time:.3f}", - ] - ) - writer.writerow( - [ - name, - "Current", - current_result.success, - current_result.images, - current_result.internal_links, - current_result.external_links, - current_result.markdown_length, - f"{current_result.execution_time:.3f}", - ] - ) - - def print_comparison_table(self, all_results: List[tuple]): - table_data = [] - headers = [ - "Test Name", - "Strategy", - "Success", - "Images", - "Internal Links", - "External Links", - "Markdown Length", - "Time (s)", - ] - - for name, new_result, current_result in all_results: - # Check for differences - differences = [] - if new_result.images != current_result.images: - differences.append("images") - if new_result.internal_links != current_result.internal_links: - differences.append("internal_links") - if new_result.external_links != current_result.external_links: - differences.append("external_links") - if new_result.markdown_length != current_result.markdown_length: - differences.append("markdown") - - # Add row for new strategy - new_row = [ - name, - "New", - new_result.success, - new_result.images, - new_result.internal_links, - new_result.external_links, - new_result.markdown_length, - f"{new_result.execution_time:.3f}", - ] - table_data.append(new_row) - - # Add row for current strategy - current_row = [ - "", - "Current", - current_result.success, - current_result.images, - current_result.internal_links, - current_result.external_links, - current_result.markdown_length, - f"{current_result.execution_time:.3f}", + # Add empty row for better readability + table_data.append([""] * len(headers)) + + print("\nStrategy Comparison Results:") + print(tabulate(table_data, headers=headers, tablefmt="grid")) + + +def write_results_to_csv(): + """Write results to CSV and print comparison table.""" + if not results: + return + csv_file: Path = Path(__file__).parent / "output/strategy_comparison_results.csv" + with csv_file.open("w", newline="") as f: + writer = csv.writer(f) + writer.writerow( + [ + "Test Name", + "Strategy", + "Success", + "Images", + "Internal Links", + "External Links", + "Cleaned HTML Length", + "Execution Time", ] - table_data.append(current_row) - - # Add difference summary if any - if differences: - table_data.append( - ["", "⚠️ Differences", ", 
".join(differences), "", "", "", "", ""] + ) + + for result in results: + writer.writerow(asdict(result)) + + +def scrapper_params() -> List[ParameterSet]: + test_cases = [ + ("Basic Extraction", {}), + ("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}), + ("Word Threshold", {"word_count_threshold": 50}), + ("CSS Selector", {"css_selector": "div.mw-parser-output > p"}), + ( + "Link Exclusions", + { + "exclude_external_links": True, + "exclude_social_media_links": True, + "exclude_domains": ["facebook.com", "twitter.com"], + }, + ), + ( + "Media Handling", + { + "exclude_external_images": True, + "image_description_min_word_threshold": 20, + }, + ), + ("Text Only", {"only_text": True, "remove_forms": True}), + ("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}), + ( + "HTML2Text Options", + { + "html2text": { + "skip_internal_links": True, + "single_line_break": True, + "mark_code": True, + "preserve_tags": ["pre", "code"], + } + }, + ), + ] + params: List[ParameterSet] = [] + for strategy_name, strategy in [ + ("new", WebScrapingStrategy()), + ("current", WebScrapingStrategyCurrent()), + ]: + for name, kwargs in test_cases: + params.append( + pytest.param( + name, + strategy_name, + strategy, + kwargs, + id=f"{name} - {strategy_name}", ) + ) - # Add empty row for better readability - table_data.append([""] * len(headers)) - - print("\nStrategy Comparison Results:") - print(tabulate(table_data, headers=headers, tablefmt="grid")) + return params + + +@pytest.mark.parametrize("name,strategy_name,strategy,kwargs", scrapper_params()) +def test_strategy( + wiki_html: str, + name: str, + strategy_name: str, + strategy: ContentScrapingStrategy, + kwargs: dict[str, Any], +): + start_time = time.time() + result: ScrapingResult = strategy.scrap( + url="https://en.wikipedia.org/wiki/Test", html=wiki_html, **kwargs + ) + assert result.success + execution_time = time.time() - start_time + + results.append( + Result( + name=name, + strategy=strategy_name, + success=result.success, + images=len(result.media.images), + internal_links=len(result.links.internal), + external_links=len(result.links.external), + cleaned_html_length=len(result.cleaned_html), + execution_time=execution_time, + ) + ) if __name__ == "__main__": - tester = StrategyTester() - tester.run_all_tests() + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_crawler_strategy.py b/tests/async/test_crawler_strategy.py index 337b5aaa8..8501bcd56 100644 --- a/tests/async/test_crawler_strategy.py +++ b/tests/async/test_crawler_strategy.py @@ -1,11 +1,8 @@ -import os import sys -import pytest -# Add the parent directory to the Python path -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) +import pytest +from crawl4ai import CacheMode from crawl4ai.async_webcrawler import AsyncWebCrawler @@ -15,7 +12,7 @@ async def test_custom_user_agent(): custom_user_agent = "MyCustomUserAgent/1.0" crawler.crawler_strategy.update_user_agent(custom_user_agent) url = "https://httpbin.org/user-agent" - result = await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result.success assert custom_user_agent in result.html @@ -26,7 +23,7 @@ async def test_custom_headers(): custom_headers = {"X-Test-Header": "TestValue"} crawler.crawler_strategy.set_custom_headers(custom_headers) url = "https://httpbin.org/headers" - result = await 
crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result.success assert "X-Test-Header" in result.html assert "TestValue" in result.html @@ -37,7 +34,7 @@ async def test_javascript_execution(): async with AsyncWebCrawler(verbose=True) as crawler: js_code = "document.body.innerHTML = '
<h1>Modified by JS</h1>';" url = "https://www.example.com" - result = await crawler.arun(url=url, bypass_cache=True, js_code=js_code) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, js_code=js_code) assert result.success assert "<h1>Modified by JS</h1>
" in result.html @@ -46,13 +43,13 @@ async def test_javascript_execution(): async def test_hook_execution(): async with AsyncWebCrawler(verbose=True) as crawler: - async def test_hook(page): + async def test_hook(page, **kwargs): await page.evaluate("document.body.style.backgroundColor = 'red';") return page crawler.crawler_strategy.set_hook("after_goto", test_hook) url = "https://www.example.com" - result = await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result.success assert "background-color: red" in result.html @@ -61,7 +58,7 @@ async def test_hook(page): async def test_screenshot(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.example.com" - result = await crawler.arun(url=url, bypass_cache=True, screenshot=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, screenshot=True) assert result.success assert result.screenshot assert isinstance(result.screenshot, str) @@ -70,4 +67,6 @@ async def test_screenshot(): # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_database_operations.py b/tests/async/test_database_operations.py index db0d328ed..3632579f1 100644 --- a/tests/async/test_database_operations.py +++ b/tests/async/test_database_operations.py @@ -1,11 +1,8 @@ -import os import sys -import pytest -# Add the parent directory to the Python path -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) +import pytest +from crawl4ai import CacheMode from crawl4ai.async_webcrawler import AsyncWebCrawler @@ -14,7 +11,7 @@ async def test_cache_url(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.example.com" # First run to cache the URL - result1 = await crawler.arun(url=url, bypass_cache=True) + result1 = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result1.success # Second run to retrieve from cache @@ -28,11 +25,11 @@ async def test_bypass_cache(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.python.org" # First run to cache the URL - result1 = await crawler.arun(url=url, bypass_cache=True) + result1 = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result1.success # Second run bypassing cache - result2 = await crawler.arun(url=url, bypass_cache=True) + result2 = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) assert result2.success assert ( result2.html != result1.html @@ -42,10 +39,11 @@ async def test_bypass_cache(): @pytest.mark.asyncio async def test_cache_size(): async with AsyncWebCrawler(verbose=True) as crawler: + await crawler.aclear_cache() initial_size = await crawler.aget_cache_size() url = "https://www.nbcnews.com/business" - await crawler.arun(url=url, bypass_cache=True) + await crawler.arun(url=url, cache_mode=CacheMode.ENABLED) new_size = await crawler.aget_cache_size() assert new_size == initial_size + 1 @@ -55,7 +53,7 @@ async def test_cache_size(): async def test_clear_cache(): async with AsyncWebCrawler(verbose=True) as crawler: url = "https://www.example.org" - await crawler.arun(url=url, bypass_cache=True) + await crawler.arun(url=url, cache_mode=CacheMode.ENABLED) initial_size = await crawler.aget_cache_size() assert initial_size > 0 @@ -69,7 +67,8 @@ async def test_clear_cache(): async def test_flush_cache(): async with AsyncWebCrawler(verbose=True) 
as crawler: url = "https://www.example.net" - await crawler.arun(url=url, bypass_cache=True) + result = await crawler.arun(url=url, cache_mode=CacheMode.ENABLED) + assert result and result.success initial_size = await crawler.aget_cache_size() assert initial_size > 0 @@ -87,4 +86,6 @@ async def test_flush_cache(): # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_dispatchers.py b/tests/async/test_dispatchers.py index 99cf4a989..7caed748c 100644 --- a/tests/async/test_dispatchers.py +++ b/tests/async/test_dispatchers.py @@ -1,15 +1,18 @@ -import pytest +import sys import time + +from httpx import codes +import pytest + from crawl4ai import ( AsyncWebCrawler, BrowserConfig, + CacheMode, + CrawlerMonitor, CrawlerRunConfig, MemoryAdaptiveDispatcher, - SemaphoreDispatcher, RateLimiter, - CrawlerMonitor, - DisplayMode, - CacheMode, + SemaphoreDispatcher, ) @@ -37,7 +40,7 @@ class TestDispatchStrategies: async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls): async with AsyncWebCrawler(config=browser_config) as crawler: dispatcher = MemoryAdaptiveDispatcher( - memory_threshold_percent=70.0, max_session_permit=2, check_interval=0.1 + memory_threshold_percent=80.0, max_session_permit=2, check_interval=0.1 ) results = await crawler.arun_many( test_urls, config=run_config, dispatcher=dispatcher @@ -50,7 +53,7 @@ async def test_memory_adaptive_with_rate_limit( ): async with AsyncWebCrawler(config=browser_config) as crawler: dispatcher = MemoryAdaptiveDispatcher( - memory_threshold_percent=70.0, + memory_threshold_percent=80.0, max_session_permit=2, check_interval=0.1, rate_limiter=RateLimiter( @@ -88,6 +91,7 @@ async def test_semaphore_with_rate_limit( assert len(results) == len(test_urls) assert all(r.success for r in results) + @pytest.mark.skip(reason="memory_wait_timeout is not a valid MemoryAdaptiveDispatcher parameter") async def test_memory_adaptive_memory_error( self, browser_config, run_config, test_urls ): @@ -140,7 +144,7 @@ async def test_rate_limit_backoff(self, browser_config, run_config): base_delay=(0.1, 0.2), max_delay=1.0, max_retries=2, - rate_limit_codes=[200], # Force rate limiting for testing + rate_limit_codes=[codes.OK], # Force rate limiting for testing ), ) start_time = time.time() @@ -151,11 +155,10 @@ async def test_rate_limit_backoff(self, browser_config, run_config): assert len(results) == len(urls) assert duration > 1.0 # Ensure rate limiting caused delays + @pytest.mark.skip(reason="max_visible_rows is not a valid CrawlerMonitor parameter") async def test_monitor_integration(self, browser_config, run_config, test_urls): async with AsyncWebCrawler(config=browser_config) as crawler: - monitor = CrawlerMonitor( - max_visible_rows=5, display_mode=DisplayMode.DETAILED - ) + monitor = CrawlerMonitor(urls_total=5) dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2, monitor=monitor) results = await crawler.arun_many( test_urls, config=run_config, dispatcher=dispatcher @@ -163,8 +166,10 @@ async def test_monitor_integration(self, browser_config, run_config, test_urls): assert len(results) == len(test_urls) # Check monitor stats assert len(monitor.stats) == len(test_urls) - assert all(stat.end_time is not None for stat in monitor.stats.values()) + assert all(stat["end_time"] is not None for stat in monitor.stats.values()) if __name__ == "__main__": - pytest.main([__file__, "-v", 
"--asyncio-mode=auto"]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_edge_cases.py b/tests/async/test_edge_cases.py index d3adb53c5..8a1cc5274 100644 --- a/tests/async/test_edge_cases.py +++ b/tests/async/test_edge_cases.py @@ -1,63 +1,67 @@ -import os +import asyncio import re import sys + import pytest from bs4 import BeautifulSoup -import asyncio - -# Add the parent directory to the Python path -parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(parent_dir) +from playwright.async_api import Page, BrowserContext +from crawl4ai import CacheMode from crawl4ai.async_webcrawler import AsyncWebCrawler -# @pytest.mark.asyncio -# async def test_large_content_page(): -# async with AsyncWebCrawler(verbose=True) as crawler: -# url = "https://en.wikipedia.org/wiki/List_of_largest_known_stars" # A page with a large table -# result = await crawler.arun(url=url, bypass_cache=True) -# assert result.success -# assert len(result.html) > 1000000 # Expecting more than 1MB of content - -# @pytest.mark.asyncio -# async def test_minimal_content_page(): -# async with AsyncWebCrawler(verbose=True) as crawler: -# url = "https://example.com" # A very simple page -# result = await crawler.arun(url=url, bypass_cache=True) -# assert result.success -# assert len(result.html) < 10000 # Expecting less than 10KB of content - -# @pytest.mark.asyncio -# async def test_single_page_application(): -# async with AsyncWebCrawler(verbose=True) as crawler: -# url = "https://reactjs.org/" # React's website is a SPA -# result = await crawler.arun(url=url, bypass_cache=True) -# assert result.success -# assert "react" in result.html.lower() - -# @pytest.mark.asyncio -# async def test_page_with_infinite_scroll(): -# async with AsyncWebCrawler(verbose=True) as crawler: -# url = "https://news.ycombinator.com/" # Hacker News has infinite scroll -# result = await crawler.arun(url=url, bypass_cache=True) -# assert result.success -# assert "hacker news" in result.html.lower() - -# @pytest.mark.asyncio -# async def test_page_with_heavy_javascript(): -# async with AsyncWebCrawler(verbose=True) as crawler: -# url = "https://www.airbnb.com/" # Airbnb uses a lot of JavaScript -# result = await crawler.arun(url=url, bypass_cache=True) -# assert result.success -# assert "airbnb" in result.html.lower() - -# @pytest.mark.asyncio -# async def test_page_with_mixed_content(): -# async with AsyncWebCrawler(verbose=True) as crawler: -# url = "https://github.com/" # GitHub has a mix of static and dynamic content -# result = await crawler.arun(url=url, bypass_cache=True) -# assert result.success -# assert "github" in result.html.lower() + +@pytest.mark.asyncio +async def test_large_content_page(): + async with AsyncWebCrawler(verbose=True) as crawler: + url = "https://en.wikipedia.org/wiki/List_of_largest_known_stars" # A page with a large table + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) + assert result.success + assert len(result.html) > 1000000 # Expecting more than 1MB of content + + +@pytest.mark.asyncio +async def test_minimal_content_page(): + async with AsyncWebCrawler(verbose=True) as crawler: + url = "https://example.com" # A very simple page + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) + assert result.success + assert len(result.html) < 10000 # Expecting less than 10KB of content + + +@pytest.mark.asyncio +async def test_single_page_application(): + async with 
AsyncWebCrawler(verbose=True) as crawler: + url = "https://reactjs.org/" # React's website is a SPA + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) + assert result.success + assert "react" in result.html.lower() + + +@pytest.mark.asyncio +async def test_page_with_infinite_scroll(): + async with AsyncWebCrawler(verbose=True) as crawler: + url = "https://news.ycombinator.com/" # Hacker News has infinite scroll + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) + assert result.success + assert "hacker news" in result.html.lower() + + +@pytest.mark.asyncio +async def test_page_with_heavy_javascript(): + async with AsyncWebCrawler(verbose=True) as crawler: + url = "https://www.airbnb.com/" # Airbnb uses a lot of JavaScript + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) + assert result.success + assert "airbnb" in result.html.lower() + + +@pytest.mark.asyncio +async def test_page_with_mixed_content(): + async with AsyncWebCrawler(verbose=True) as crawler: + url = "https://github.com/" # GitHub has a mix of static and dynamic content + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) + assert result.success + assert "github" in result.html.lower() # Add this test to your existing test file @@ -65,13 +69,13 @@ async def test_typescript_commits_multi_page(): first_commit = "" - async def on_execution_started(page): + async def on_execution_started(page: Page, context: BrowserContext, **kwargs): nonlocal first_commit try: - # Check if the page firct commit h4 text is different from the first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4')) + # Check if the page first commit h4 text is different from the first commit (use document.querySelector('div.Box-sc-g0xbh4-0 h4')) while True: - await page.wait_for_selector("li.Box-sc-g0xbh4-0 h4") - commit = await page.query_selector("li.Box-sc-g0xbh4-0 h4") + await page.wait_for_selector("div.Box-sc-g0xbh4-0 h4") + commit = await page.query_selector("div.Box-sc-g0xbh4-0 h4") commit = await commit.evaluate("(element) => element.textContent") commit = re.sub(r"\s+", "", commit) if commit and commit != first_commit: @@ -97,16 +101,14 @@ async def on_execution_started(page): result = await crawler.arun( url=url, # Only use URL for the first page session_id=session_id, - css_selector="li.Box-sc-g0xbh4-0", - js=js_next_page - if page > 0 - else None, # Don't click 'next' on the first page - bypass_cache=True, + css_selector="div.Box-sc-g0xbh4-0", + js_code=js_next_page if page > 0 else None, # Don't click 'next' on the first page + cache_mode=CacheMode.BYPASS, js_only=page > 0, # Use js_only for subsequent pages ) assert result.success, f"Failed to crawl page {page + 1}" - + assert result.cleaned_html, f"No cleaned HTML found for page {page + 1}" # Parse the HTML and extract commits soup = BeautifulSoup(result.cleaned_html, "html.parser") commits = soup.select("li") @@ -121,13 +123,13 @@ async def on_execution_started(page): await crawler.crawler_strategy.kill_session(session_id) # Assertions - assert ( - len(all_commits) >= 90 - ), f"Expected at least 90 commits, but got {len(all_commits)}" + assert len(all_commits) >= 90, f"Expected at least 90 commits, but got {len(all_commits)}" print(f"Successfully crawled {len(all_commits)} commits across 3 pages") # Entry point for debugging if __name__ == "__main__": - pytest.main([__file__, "-v"]) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_error_handling.py 
b/tests/async/test_error_handling.py index ae4af6c84..ca06722da 100644 --- a/tests/async/test_error_handling.py +++ b/tests/async/test_error_handling.py @@ -1,78 +1,83 @@ -# import os -# import sys -# import pytest -# import asyncio - -# # Add the parent directory to the Python path -# parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -# sys.path.append(parent_dir) - -# from crawl4ai.async_webcrawler import AsyncWebCrawler -# from crawl4ai.utils import InvalidCSSSelectorError - -# class AsyncCrawlerWrapper: -# def __init__(self): -# self.crawler = None - -# async def setup(self): -# self.crawler = AsyncWebCrawler(verbose=True) -# await self.crawler.awarmup() - -# async def cleanup(self): -# if self.crawler: -# await self.crawler.aclear_cache() - -# @pytest.fixture(scope="module") -# def crawler_wrapper(): -# wrapper = AsyncCrawlerWrapper() -# asyncio.get_event_loop().run_until_complete(wrapper.setup()) -# yield wrapper -# asyncio.get_event_loop().run_until_complete(wrapper.cleanup()) - -# @pytest.mark.asyncio -# async def test_network_error(crawler_wrapper): -# url = "https://www.nonexistentwebsite123456789.com" -# result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True) -# assert not result.success -# assert "Failed to crawl" in result.error_message - -# # @pytest.mark.asyncio -# # async def test_timeout_error(crawler_wrapper): -# # # Simulating a timeout by using a very short timeout value -# # url = "https://www.nbcnews.com/business" -# # result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True, timeout=0.001) -# # assert not result.success -# # assert "timeout" in result.error_message.lower() - -# # @pytest.mark.asyncio -# # async def test_invalid_css_selector(crawler_wrapper): -# # url = "https://www.nbcnews.com/business" -# # with pytest.raises(InvalidCSSSelectorError): -# # await crawler_wrapper.crawler.arun(url=url, bypass_cache=True, css_selector="invalid>>selector") - -# # @pytest.mark.asyncio -# # async def test_js_execution_error(crawler_wrapper): -# # url = "https://www.nbcnews.com/business" -# # invalid_js = "This is not valid JavaScript code;" -# # result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True, js=invalid_js) -# # assert not result.success -# # assert "JavaScript" in result.error_message - -# # @pytest.mark.asyncio -# # async def test_empty_page(crawler_wrapper): -# # # Use a URL that typically returns an empty page -# # url = "http://example.com/empty" -# # result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True) -# # assert result.success # The crawl itself should succeed -# # assert not result.markdown.strip() # The markdown content should be empty or just whitespace - -# # @pytest.mark.asyncio -# # async def test_rate_limiting(crawler_wrapper): -# # # Simulate rate limiting by making multiple rapid requests -# # url = "https://www.nbcnews.com/business" -# # results = await asyncio.gather(*[crawler_wrapper.crawler.arun(url=url, bypass_cache=True) for _ in range(10)]) -# # assert any(not result.success and "rate limit" in result.error_message.lower() for result in results) - -# # Entry point for debugging -# if __name__ == "__main__": -# pytest.main([__file__, "-v"]) +import asyncio +import sys + +import pytest +import pytest_asyncio + +from crawl4ai import CacheMode +from crawl4ai.async_webcrawler import AsyncWebCrawler +from crawl4ai.utils import InvalidCSSSelectorError + + +@pytest_asyncio.fixture +async def crawler(): + async with AsyncWebCrawler(verbose=True, warmup=False) as 
crawler: + yield crawler + + +@pytest.mark.asyncio +async def test_network_error(crawler: AsyncWebCrawler): + url = "https://www.nonexistentwebsite123456789.com" + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) + assert not result.success + assert result.error_message + assert "Failed on navigating ACS-GOTO" in result.error_message + + +@pytest.mark.asyncio +async def test_timeout_error(crawler: AsyncWebCrawler): + # Simulating a timeout by using a very short timeout value + url = "https://www.nbcnews.com/business" + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, page_timeout=0.001) + assert not result.success + assert result.error_message + assert "timeout" in result.error_message.lower() + + +@pytest.mark.asyncio +@pytest.mark.skip("Invalid CSS selector not raised any more") +async def test_invalid_css_selector(crawler: AsyncWebCrawler): + url = "https://www.nbcnews.com/business" + with pytest.raises(InvalidCSSSelectorError): + await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, css_selector="invalid>>selector") + + +@pytest.mark.asyncio +async def test_js_execution_error(crawler: AsyncWebCrawler): + url = "https://www.nbcnews.com/business" + invalid_js = "This is not valid JavaScript code;" + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, js_code=invalid_js) + assert result.success + assert result.js_execution_result + assert result.js_execution_result["success"] + results: list[dict] = result.js_execution_result["results"] + assert results + assert not results[0]["success"] + assert "SyntaxError" in results[0]["error"] + + +@pytest.mark.asyncio +@pytest.mark.skip("The page is not empty any more") +async def test_empty_page(crawler: AsyncWebCrawler): + # Use a URL that typically returns an empty page + url = "http://example.com/empty" + result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS) + assert result.success # The crawl itself should succeed + assert result.markdown is not None + assert not result.markdown.strip() # The markdown content should be empty or just whitespace + + +@pytest.mark.asyncio +@pytest.mark.skip("Rate limiting doesn't trigger") +async def test_rate_limiting(crawler: AsyncWebCrawler): + # Simulate rate limiting by making multiple rapid requests + url = "https://www.nbcnews.com/business" + results = await asyncio.gather(*[crawler.arun(url=url, cache_mode=CacheMode.BYPASS) for _ in range(10)]) + assert any(not result.success and result.error_message and "rate limit" in result.error_message.lower() for result in results) + + +# Entry point for debugging +if __name__ == "__main__": + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/async/test_evaluation_scraping_methods_performance.configs.py b/tests/async/test_evaluation_scraping_methods_performance_configs.py similarity index 52% rename from tests/async/test_evaluation_scraping_methods_performance.configs.py rename to tests/async/test_evaluation_scraping_methods_performance_configs.py index 797cf681c..e714a52f1 100644 --- a/tests/async/test_evaluation_scraping_methods_performance.configs.py +++ b/tests/async/test_evaluation_scraping_methods_performance_configs.py @@ -1,13 +1,20 @@ -import json +import difflib +import sys import time +from typing import Any, List, Tuple + +import pytest +from _pytest.mark import ParameterSet from bs4 import BeautifulSoup +from bs4.element import Tag +from lxml import etree +from lxml import html as lhtml + from crawl4ai.content_scraping_strategy import 
( - WebScrapingStrategy, LXMLWebScrapingStrategy, + WebScrapingStrategy, ) -from typing import Dict, List, Tuple -import difflib -from lxml import html as lhtml, etree +from crawl4ai.models import Links, Media, ScrapingResult def normalize_dom(element): @@ -197,7 +204,7 @@ def generate_complicated_html(): Complicated Test Page - + + + + +
+<div>
+<h1>Example Domain</h1>
+<p>This domain is for use in illustrative examples in documents. You may use this
+domain in literature without prior coordination or asking for permission.</p>
+<p>More information...</p>
+</div>
+ +""") + + app = web.Application() + app.router.add_get("/{route}", handler) + return await aiohttp_server(app) + +@pytest_asyncio.fixture +async def browsers_manager(tmp_path: Path) -> AsyncGenerator[BrowsersManager, None]: + manager: BrowsersManager = BrowsersManager() + await manager.run(tmp_path) + yield manager + await manager.close() + + +@pytest_asyncio.fixture +async def manager() -> AsyncGenerator[BrowserManager, None]: + browser_config: BrowserConfig = BrowserConfig(browser_mode="builtin", headless=True, verbose=True) + manager: BrowserManager = BrowserManager(browser_config=browser_config) + await manager.start() + yield manager + await manager.close() + +@pytest.mark.asyncio +@pytest.mark.parametrize("browser_type", ["webkit", "firefox"]) +async def test_not_supported(tmp_path: Path, browser_type: str): + browser_config: BrowserConfig = BrowserConfig( + browser_mode="builtin", + browser_type=browser_type, + headless=True, + debugging_port=0, + user_data_dir=str(tmp_path), + ) + logger: AsyncLogger = AsyncLogger() + manager: BrowserManager = BrowserManager(browser_config=browser_config, logger=logger) + assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}" + manager._strategy.shutting_down = True + with pytest.raises(Exception): + await manager.start() - # Step 3: Start the manager to launch or connect to builtin browser - print(f"\n{INFO}3. Starting the browser manager{RESET}") +@pytest.mark.asyncio +@pytest.mark.parametrize("browser_type", ["chromium"]) +async def test_ephemeral_port(tmp_path: Path, browser_type: str): + browser_config: BrowserConfig = BrowserConfig( + browser_mode="builtin", + browser_type=browser_type, + headless=True, + debugging_port=0, + user_data_dir=str(tmp_path), + ) + logger: AsyncLogger = AsyncLogger() + manager: BrowserManager = BrowserManager(browser_config=browser_config, logger=logger) + assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}" + manager._strategy.shutting_down = True try: await manager.start() - print(f"{SUCCESS}Browser manager started successfully{RESET}") - except Exception as e: - print(f"{ERROR}Failed to start browser manager: {str(e)}{RESET}") - return None - - # Step 4: Get browser info from the strategy - print(f"\n{INFO}4. 
Getting browser information{RESET}") - browser_info = manager._strategy.get_builtin_browser_info() - if browser_info: - print(f"{SUCCESS}Browser info retrieved:{RESET}") - for key, value in browser_info.items(): - if key != "config": # Skip the verbose config section - print(f" {key}: {value}") - - cdp_url = browser_info.get("cdp_url") - print(f"{SUCCESS}CDP URL: {cdp_url}{RESET}") - else: - print(f"{ERROR}Failed to get browser information{RESET}") - cdp_url = None - - # Save manager for later tests - return manager, cdp_url + assert manager._strategy.config.debugging_port != 0, "Ephemeral port not assigned" + finally: + await manager.close() +@pytest.mark.asyncio +async def test_builtin_browser_creation(manager: BrowserManager): + """Test creating a builtin browser using the BrowserManager with BuiltinBrowserStrategy""" + # Check if we have a BuiltinBrowserStrategy + assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}" + + # Check we can get browser info from the strategy + strategy: BuiltinBrowserStrategy = manager._strategy + browser_info = strategy.get_browser_info() + assert browser_info, "Failed to get browser info" +@pytest.mark.asyncio async def test_page_operations(manager: BrowserManager): """Test page operations with the builtin browser""" - print( - f"\n{INFO}========== Testing Page Operations with Builtin Browser =========={RESET}" - ) # Step 1: Get a single page - print(f"\n{INFO}1. Getting a single page{RESET}") - try: - crawler_config = CrawlerRunConfig() - page, context = await manager.get_page(crawler_config) - print(f"{SUCCESS}Got page successfully{RESET}") + crawler_config = CrawlerRunConfig() + page, context = await manager.get_page(crawler_config) - # Navigate to a test URL - await page.goto("https://example.com") - title = await page.title() - print(f"{SUCCESS}Page title: {title}{RESET}") + # Navigate to a test URL + await page.goto("https://example.com") + title = await page.title() - # Close the page - await page.close() - print(f"{SUCCESS}Page closed successfully{RESET}") - except Exception as e: - print(f"{ERROR}Page operation failed: {str(e)}{RESET}") - return False + # Close the page + await page.close() # Step 2: Get multiple pages - print(f"\n{INFO}2. 
Getting multiple pages with get_pages(){RESET}") - try: - # Request 3 pages - crawler_config = CrawlerRunConfig() - pages = await manager.get_pages(crawler_config, count=3) - print(f"{SUCCESS}Got {len(pages)} pages{RESET}") - - # Test each page - for i, (page, context) in enumerate(pages): - await page.goto(f"https://example.com?test={i}") - title = await page.title() - print(f"{SUCCESS}Page {i + 1} title: {title}{RESET}") - await page.close() - - print(f"{SUCCESS}All pages tested and closed successfully{RESET}") - except Exception as e: - print(f"{ERROR}Multiple page operation failed: {str(e)}{RESET}") - return False - return True + # Request 3 pages + crawler_config = CrawlerRunConfig() + pages = await manager.get_pages(crawler_config, count=3) + # Test each page + for i, (page, context) in enumerate(pages): + response = await page.goto(f"https://example.com?test={i}") + assert response, f"Failed to load page {i + 1}" + title: str = await page.title() + assert title == "Example Domain", f"Expected title 'Example Domain', got '{title}'" + await page.close() +@pytest.mark.asyncio async def test_browser_status_management(manager: BrowserManager): """Test browser status and management operations""" - print(f"\n{INFO}========== Testing Browser Status and Management =========={RESET}") - - # Step 1: Get browser status - print(f"\n{INFO}1. Getting browser status{RESET}") - try: - status = await manager._strategy.get_builtin_browser_status() - print(f"{SUCCESS}Browser status:{RESET}") - print(f" Running: {status['running']}") - print(f" CDP URL: {status['cdp_url']}") - except Exception as e: - print(f"{ERROR}Failed to get browser status: {str(e)}{RESET}") - return False + assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}" + status = await manager._strategy.get_builtin_browser_status() # Step 2: Test killing the browser - print(f"\n{INFO}2. Testing killing the browser{RESET}") - try: - result = await manager._strategy.kill_builtin_browser() - if result: - print(f"{SUCCESS}Browser killed successfully{RESET}") - else: - print(f"{ERROR}Failed to kill browser{RESET}") - except Exception as e: - print(f"{ERROR}Browser kill operation failed: {str(e)}{RESET}") - return False + result = await manager._strategy.kill_builtin_browser() + assert result, "Failed to kill the browser" # Step 3: Check status after kill - print(f"\n{INFO}3. Checking status after kill{RESET}") - try: - status = await manager._strategy.get_builtin_browser_status() - if not status["running"]: - print(f"{SUCCESS}Browser is correctly reported as not running{RESET}") - else: - print(f"{ERROR}Browser is incorrectly reported as still running{RESET}") - except Exception as e: - print(f"{ERROR}Failed to get browser status: {str(e)}{RESET}") - return False + status = await manager._strategy.get_builtin_browser_status() + assert status, "Failed to get browser status after kill" + assert not status["running"], "Browser is still running after kill" # Step 4: Launch a new browser - print(f"\n{INFO}4. 
Launching a new browser{RESET}") - try: - cdp_url = await manager._strategy.launch_builtin_browser( - browser_type="chromium", headless=True - ) - if cdp_url: - print(f"{SUCCESS}New browser launched at: {cdp_url}{RESET}") - else: - print(f"{ERROR}Failed to launch new browser{RESET}") - return False - except Exception as e: - print(f"{ERROR}Browser launch failed: {str(e)}{RESET}") - return False - - return True - + cdp_url = await manager._strategy.launch_builtin_browser( + browser_type="chromium", headless=True + ) + assert cdp_url, "Failed to launch a new browser" +@pytest.mark.asyncio async def test_multiple_managers(): """Test creating multiple BrowserManagers that use the same builtin browser""" - print(f"\n{INFO}========== Testing Multiple Browser Managers =========={RESET}") # Step 1: Create first manager - print(f"\n{INFO}1. Creating first browser manager{RESET}") - browser_config1 = (BrowserConfig(browser_mode="builtin", headless=True),) - manager1 = BrowserManager(browser_config=browser_config1, logger=logger) + browser_config1: BrowserConfig = BrowserConfig(browser_mode="builtin", headless=True) + manager1: BrowserManager = BrowserManager(browser_config=browser_config1) # Step 2: Create second manager - print(f"\n{INFO}2. Creating second browser manager{RESET}") - browser_config2 = BrowserConfig(browser_mode="builtin", headless=True) - manager2 = BrowserManager(browser_config=browser_config2, logger=logger) + browser_config2: BrowserConfig = BrowserConfig(browser_mode="builtin", headless=True) + manager2: BrowserManager = BrowserManager(browser_config=browser_config2) # Step 3: Start both managers (should connect to the same builtin browser) - print(f"\n{INFO}3. Starting both managers{RESET}") + page1: Optional[Page] = None + page2: Optional[Page] = None try: await manager1.start() - print(f"{SUCCESS}First manager started{RESET}") - await manager2.start() - print(f"{SUCCESS}Second manager started{RESET}") # Check if they got the same CDP URL cdp_url1 = manager1._strategy.config.cdp_url cdp_url2 = manager2._strategy.config.cdp_url - if cdp_url1 == cdp_url2: - print( - f"{SUCCESS}Both managers connected to the same browser: {cdp_url1}{RESET}" - ) - else: - print( - f"{WARNING}Managers connected to different browsers: {cdp_url1} and {cdp_url2}{RESET}" - ) - except Exception as e: - print(f"{ERROR}Failed to start managers: {str(e)}{RESET}") - return False + assert cdp_url1 == cdp_url2, "CDP URLs do not match between managers" - # Step 4: Test using both managers - print(f"\n{INFO}4. Testing operations with both managers{RESET}") - try: + # Step 4: Test using both managers # First manager creates a page page1, ctx1 = await manager1.get_page(CrawlerRunConfig()) await page1.goto("https://example.com") - title1 = await page1.title() - print(f"{SUCCESS}Manager 1 page title: {title1}{RESET}") + title1: str = await page1.title() + assert title1 == "Example Domain", f"Expected title 'Example Domain', got '{title1}'" # Second manager creates a page page2, ctx2 = await manager2.get_page(CrawlerRunConfig()) await page2.goto("https://example.org") - title2 = await page2.title() - print(f"{SUCCESS}Manager 2 page title: {title2}{RESET}") - - # Clean up - await page1.close() - await page2.close() - except Exception as e: - print(f"{ERROR}Failed to use both managers: {str(e)}{RESET}") - return False - - # Step 5: Close both managers - print(f"\n{INFO}5. 
Closing both managers{RESET}") - try: + title2: str = await page2.title() + assert title2 == "Example Domain", f"Expected title 'Example Domain', got '{title2}'" + finally: + if page1: + await page1.close() + if page2: + await page2.close() + # Close both managers await manager1.close() - print(f"{SUCCESS}First manager closed{RESET}") - await manager2.close() - print(f"{SUCCESS}Second manager closed{RESET}") - except Exception as e: - print(f"{ERROR}Failed to close managers: {str(e)}{RESET}") - return False - return True - - -async def test_edge_cases(): - """Test edge cases like multiple starts, killing browser during operations, etc.""" - print(f"\n{INFO}========== Testing Edge Cases =========={RESET}") - - # Step 1: Test multiple starts with the same manager - print(f"\n{INFO}1. Testing multiple starts with the same manager{RESET}") - browser_config = BrowserConfig(browser_mode="builtin", headless=True) - manager = BrowserManager(browser_config=browser_config, logger=logger) +@pytest.mark.asyncio +async def test_multiple_starts(manager: BrowserManager): + """Test multiple starts with the same manager.""" + page: Optional[Page] = None try: - await manager.start() - print(f"{SUCCESS}First start successful{RESET}") - # Try to start again await manager.start() - print(f"{SUCCESS}Second start completed without errors{RESET}") # Test if it's still functional page, context = await manager.get_page(CrawlerRunConfig()) + assert page is not None, "Failed to create a page after multiple starts" await page.goto("https://example.com") - title = await page.title() - print( - f"{SUCCESS}Page operations work after multiple starts. Title: {title}{RESET}" - ) - await page.close() - except Exception as e: - print(f"{ERROR}Multiple starts test failed: {str(e)}{RESET}") - return False + title: str = await page.title() + assert title == "Example Domain", f"Expected title 'Example Domain', got '{title}'" finally: - await manager.close() - - # Step 2: Test killing the browser while manager is active - print(f"\n{INFO}2. 
Testing killing the browser while manager is active{RESET}") - manager = BrowserManager(browser_config=browser_config, logger=logger) - - try: - await manager.start() - print(f"{SUCCESS}Manager started{RESET}") - - # Kill the browser directly - print(f"{INFO}Killing the browser...{RESET}") - await manager._strategy.kill_builtin_browser() - print(f"{SUCCESS}Browser killed{RESET}") - - # Try to get a page (should fail or launch a new browser) - try: - page, context = await manager.get_page(CrawlerRunConfig()) - print( - f"{WARNING}Page request succeeded despite killed browser (might have auto-restarted){RESET}" - ) - title = await page.title() - print(f"{SUCCESS}Got page title: {title}{RESET}") + if page: await page.close() - except Exception as e: - print( - f"{SUCCESS}Page request failed as expected after browser was killed: {str(e)}{RESET}" - ) - except Exception as e: - print(f"{ERROR}Kill during operation test failed: {str(e)}{RESET}") - return False - finally: await manager.close() - return True +@pytest.mark.asyncio +async def test_kill_while_active(manager: BrowserManager): + """Test killing the browser while manager is active.""" + assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}" + await manager._strategy.kill_builtin_browser() + with pytest.raises(Exception): + # Try to get a page should fail + await manager.get_page(CrawlerRunConfig()) -async def cleanup_browsers(): - """Clean up any remaining builtin browsers""" - print(f"\n{INFO}========== Cleaning Up Builtin Browsers =========={RESET}") - - browser_config = BrowserConfig(browser_mode="builtin", headless=True) - manager = BrowserManager(browser_config=browser_config, logger=logger) - - try: - # No need to start, just access the strategy directly - strategy = manager._strategy - if isinstance(strategy, BuiltinBrowserStrategy): - result = await strategy.kill_builtin_browser() - if result: - print(f"{SUCCESS}Successfully killed all builtin browsers{RESET}") - else: - print(f"{WARNING}No builtin browsers found to kill{RESET}") - else: - print(f"{ERROR}Wrong strategy type: {strategy.__class__.__name__}{RESET}") - except Exception as e: - print(f"{ERROR}Cleanup failed: {str(e)}{RESET}") - finally: - # Just to be safe - try: - await manager.close() - except: - pass - - -async def test_performance_scaling(): +@pytest.mark.asyncio +@pytest.mark.timeout(30) +async def test_performance_scaling(browsers_manager: BrowsersManager, test_server: TestServer): """Test performance with multiple browsers and pages. This test creates multiple browsers on different ports, spawns multiple pages per browser, and measures performance metrics. 
""" - print(f"\n{INFO}========== Testing Performance Scaling =========={RESET}") - - # Configuration parameters - num_browsers = 10 - pages_per_browser = 10 - total_pages = num_browsers * pages_per_browser - base_port = 9222 - - # Set up a measuring mechanism for memory - import psutil - import gc - - # Force garbage collection before starting - gc.collect() - process = psutil.Process() - initial_memory = process.memory_info().rss / 1024 / 1024 # in MB - peak_memory = initial_memory - - # Report initial configuration - print( - f"{INFO}Test configuration: {num_browsers} browsers × {pages_per_browser} pages = {total_pages} total crawls{RESET}" - ) - - # List to track managers - managers: List[BrowserManager] = [] - all_pages = [] - - - - # Get crawl4ai home directory - crawl4ai_home = os.path.expanduser("~/.crawl4ai") - temp_dir = os.path.join(crawl4ai_home, "temp") - os.makedirs(temp_dir, exist_ok=True) - - # Create all managers but don't start them yet - manager_configs = [] - for i in range(num_browsers): - port = base_port + i - browser_config = BrowserConfig( - browser_mode="builtin", - headless=True, - debugging_port=port, - user_data_dir=os.path.join(temp_dir, f"browser_profile_{i}"), - ) - manager = BrowserManager(browser_config=browser_config, logger=logger) - manager._strategy.shutting_down = True - manager_configs.append((manager, i, port)) - - # Define async function to start a single manager - async def start_manager(manager, index, port): - try: - await manager.start() - return manager - except Exception as e: - print( - f"{ERROR}Failed to start browser {index + 1} on port {port}: {str(e)}{RESET}" - ) - return None - - # Start all managers in parallel - start_tasks = [ - start_manager(manager, i, port) for manager, i, port in manager_configs - ] - started_managers = await asyncio.gather(*start_tasks) - - # Filter out None values (failed starts) and add to managers list - managers = [m for m in started_managers if m is not None] - - if len(managers) == 0: - print(f"{ERROR}All browser managers failed to start. Aborting test.{RESET}") - return False - - if len(managers) < num_browsers: - print( - f"{WARNING}Only {len(managers)} out of {num_browsers} browser managers started successfully{RESET}" - ) - - # Create pages for each browser - for i, manager in enumerate(managers): - try: - pages = await manager.get_pages(CrawlerRunConfig(), count=pages_per_browser) - all_pages.extend(pages) - except Exception as e: - print(f"{ERROR}Failed to create pages for browser {i + 1}: {str(e)}{RESET}") - - # Check memory after page creation - gc.collect() - current_memory = process.memory_info().rss / 1024 / 1024 - peak_memory = max(peak_memory, current_memory) + assert browsers_manager.managers, "Failed to start any browser managers" # Ask for confirmation before loading confirmation = input( f"{WARNING}Do you want to proceed with loading pages? 
(y/n): {RESET}" - ) - # Step 1: Create and start multiple browser managers in parallel - start_time = time.time() - - if confirmation.lower() == "y": - load_start_time = time.time() - - # Function to load a single page - async def load_page(page_ctx, index): - page, _ = page_ctx - try: - await page.goto(f"https://example.com/page{index}", timeout=30000) - title = await page.title() - return title - except Exception as e: - return f"Error: {str(e)}" - - # Load all pages concurrently - load_tasks = [load_page(page_ctx, i) for i, page_ctx in enumerate(all_pages)] - load_results = await asyncio.gather(*load_tasks, return_exceptions=True) - - # Count successes and failures - successes = sum( - 1 for r in load_results if isinstance(r, str) and not r.startswith("Error") - ) - failures = len(load_results) - successes - - load_time = time.time() - load_start_time - total_test_time = time.time() - start_time - - # Check memory after loading (peak memory) - gc.collect() - current_memory = process.memory_info().rss / 1024 / 1024 - peak_memory = max(peak_memory, current_memory) - - # Calculate key metrics - memory_per_page = peak_memory / successes if successes > 0 else 0 - time_per_crawl = total_test_time / successes if successes > 0 else 0 - crawls_per_second = successes / total_test_time if total_test_time > 0 else 0 - crawls_per_minute = crawls_per_second * 60 - crawls_per_hour = crawls_per_minute * 60 - - # Print simplified performance summary - from rich.console import Console - from rich.table import Table - - console = Console() - - # Create a simple summary table - table = Table(title="CRAWL4AI PERFORMANCE SUMMARY") + ) if sys.stdin.isatty() else "y" - table.add_column("Metric", style="cyan") - table.add_column("Value", style="green") + assert confirmation.lower() == "y", "User aborted the test" - table.add_row("Total Crawls Completed", f"{successes}") - table.add_row("Total Time", f"{total_test_time:.2f} seconds") - table.add_row("Time Per Crawl", f"{time_per_crawl:.2f} seconds") - table.add_row("Crawling Speed", f"{crawls_per_second:.2f} crawls/second") - table.add_row("Projected Rate (1 minute)", f"{crawls_per_minute:.0f} crawls") - table.add_row("Projected Rate (1 hour)", f"{crawls_per_hour:.0f} crawls") - table.add_row("Peak Memory Usage", f"{peak_memory:.2f} MB") - table.add_row("Memory Per Crawl", f"{memory_per_page:.2f} MB") - - # Display the table - console.print(table) - - # Ask confirmation before cleanup - confirmation = input( - f"{WARNING}Do you want to proceed with cleanup? (y/n): {RESET}" - ) - if confirmation.lower() != "y": - print(f"{WARNING}Cleanup aborted by user{RESET}") - return False - - # Close all pages - for page, _ in all_pages: - try: - await page.close() - except: - pass - - # Close all managers - for manager in managers: - try: - await manager.close() - except: - pass - - # Remove the temp directory - import shutil - - if os.path.exists(temp_dir): - shutil.rmtree(temp_dir) - - return True - - -async def test_performance_scaling_lab( num_browsers: int = 10, pages_per_browser: int = 10): - """Test performance with multiple browsers and pages. - - This test creates multiple browsers on different ports, - spawns multiple pages per browser, and measures performance metrics. 
- """ - print(f"\n{INFO}========== Testing Performance Scaling =========={RESET}") - - # Configuration parameters - num_browsers = num_browsers - pages_per_browser = pages_per_browser - total_pages = num_browsers * pages_per_browser - base_port = 9222 - - # Set up a measuring mechanism for memory - import psutil - import gc - - # Force garbage collection before starting - gc.collect() - process = psutil.Process() - initial_memory = process.memory_info().rss / 1024 / 1024 # in MB - peak_memory = initial_memory - - # Report initial configuration - print( - f"{INFO}Test configuration: {num_browsers} browsers × {pages_per_browser} pages = {total_pages} total crawls{RESET}" - ) - - # List to track managers - managers: List[BrowserManager] = [] - all_pages = [] - - # Get crawl4ai home directory - crawl4ai_home = os.path.expanduser("~/.crawl4ai") - temp_dir = os.path.join(crawl4ai_home, "temp") - os.makedirs(temp_dir, exist_ok=True) - - # Create all managers but don't start them yet - manager_configs = [] - for i in range(num_browsers): - port = base_port + i - browser_config = BrowserConfig( - browser_mode="builtin", - headless=True, - debugging_port=port, - user_data_dir=os.path.join(temp_dir, f"browser_profile_{i}"), - ) - manager = BrowserManager(browser_config=browser_config, logger=logger) - manager._strategy.shutting_down = True - manager_configs.append((manager, i, port)) - - # Define async function to start a single manager - async def start_manager(manager, index, port): - try: - await manager.start() - return manager - except Exception as e: - print( - f"{ERROR}Failed to start browser {index + 1} on port {port}: {str(e)}{RESET}" - ) - return None - - # Start all managers in parallel - start_tasks = [ - start_manager(manager, i, port) for manager, i, port in manager_configs - ] - started_managers = await asyncio.gather(*start_tasks) - - # Filter out None values (failed starts) and add to managers list - managers = [m for m in started_managers if m is not None] - - if len(managers) == 0: - print(f"{ERROR}All browser managers failed to start. Aborting test.{RESET}") - return False - - if len(managers) < num_browsers: - print( - f"{WARNING}Only {len(managers)} out of {num_browsers} browser managers started successfully{RESET}" - ) - - # Create pages for each browser - for i, manager in enumerate(managers): - try: - pages = await manager.get_pages(CrawlerRunConfig(), count=pages_per_browser) - all_pages.extend(pages) - except Exception as e: - print(f"{ERROR}Failed to create pages for browser {i + 1}: {str(e)}{RESET}") - - # Check memory after page creation - gc.collect() - current_memory = process.memory_info().rss / 1024 / 1024 - peak_memory = max(peak_memory, current_memory) - - # Ask for confirmation before loading - confirmation = input( - f"{WARNING}Do you want to proceed with loading pages? 
(y/n): {RESET}" - ) # Step 1: Create and start multiple browser managers in parallel start_time = time.time() - - if confirmation.lower() == "y": - load_start_time = time.time() - - # Function to load a single page - async def load_page(page_ctx, index): - page, _ = page_ctx - try: - await page.goto(f"https://example.com/page{index}", timeout=30000) - title = await page.title() - return title - except Exception as e: - return f"Error: {str(e)}" - - # Load all pages concurrently - load_tasks = [load_page(page_ctx, i) for i, page_ctx in enumerate(all_pages)] - load_results = await asyncio.gather(*load_tasks, return_exceptions=True) - - # Count successes and failures - successes = sum( - 1 for r in load_results if isinstance(r, str) and not r.startswith("Error") - ) - failures = len(load_results) - successes - - load_time = time.time() - load_start_time - total_test_time = time.time() - start_time - # Check memory after loading (peak memory) - gc.collect() - current_memory = process.memory_info().rss / 1024 / 1024 - peak_memory = max(peak_memory, current_memory) - - # Calculate key metrics - memory_per_page = peak_memory / successes if successes > 0 else 0 - time_per_crawl = total_test_time / successes if successes > 0 else 0 - crawls_per_second = successes / total_test_time if total_test_time > 0 else 0 - crawls_per_minute = crawls_per_second * 60 - crawls_per_hour = crawls_per_minute * 60 - - # Print simplified performance summary - from rich.console import Console - from rich.table import Table - - console = Console() - - # Create a simple summary table - table = Table(title="CRAWL4AI PERFORMANCE SUMMARY") - - table.add_column("Metric", style="cyan") - table.add_column("Value", style="green") - - table.add_row("Total Crawls Completed", f"{successes}") - table.add_row("Total Time", f"{total_test_time:.2f} seconds") - table.add_row("Time Per Crawl", f"{time_per_crawl:.2f} seconds") - table.add_row("Crawling Speed", f"{crawls_per_second:.2f} crawls/second") - table.add_row("Projected Rate (1 minute)", f"{crawls_per_minute:.0f} crawls") - table.add_row("Projected Rate (1 hour)", f"{crawls_per_hour:.0f} crawls") - table.add_row("Peak Memory Usage", f"{peak_memory:.2f} MB") - table.add_row("Memory Per Crawl", f"{memory_per_page:.2f} MB") - - # Display the table - console.print(table) - - # Ask confirmation before cleanup - confirmation = input( - f"{WARNING}Do you want to proceed with cleanup? (y/n): {RESET}" - ) - if confirmation.lower() != "y": - print(f"{WARNING}Cleanup aborted by user{RESET}") - return False - - # Close all pages - for page, _ in all_pages: + # Function to load a single page + url = test_server.make_url("/page") + async def load_page(page_ctx: Tuple[Page, Any], index): + page, _ = page_ctx try: - await page.close() - except: - pass - - # Close all managers - for manager in managers: - try: - await manager.close() - except: - pass - - # Remove the temp directory - import shutil - - if os.path.exists(temp_dir): - shutil.rmtree(temp_dir) - - return True + response: Optional[Response] = await page.goto(f"{url}{index}", timeout=5000) # example.com tends to hang connections under load. 
+ if response is None: + print(f"{ERROR}Failed to load page {index}: No response{RESET}") + return "Error: No response" + if response.status != codes.OK: + print(f"{ERROR}Failed to load page {index}: {response.status}{RESET}") + return f"Error: {response.status}" + return await page.title() + except Exception as e: + print(f"{ERROR}Failed to load page {index}: {str(e)}{RESET}") + return f"Error: {str(e)}" -async def main(): - """Run all tests""" - try: - print(f"{INFO}Starting builtin browser tests with browser module{RESET}") + # Load all pages concurrently + load_tasks = [load_page(page_ctx, i + 1) for i, page_ctx in enumerate(browsers_manager.all_pages)] + load_results = await asyncio.gather(*load_tasks) - # # Run browser creation test - # manager, cdp_url = await test_builtin_browser_creation() - # if not manager: - # print(f"{ERROR}Browser creation failed, cannot continue tests{RESET}") - # return + # Count successes and failures + successes = sum( + 1 for r in load_results if isinstance(r, str) and not r.startswith("Error") + ) + failures = len(load_results) - successes - # # Run page operations test - # await test_page_operations(manager) + assert not failures, f"Failed to load {failures} pages" - # # Run browser status and management test - # await test_browser_status_management(manager) + total_test_time = time.time() - start_time - # # Close manager before multiple manager test - # await manager.close() + # Check memory after loading (peak memory) + browsers_manager.check_memory() - # Run multiple managers test - # await test_multiple_managers() + # Calculate key metrics + memory_per_page = browsers_manager.peak_memory / successes if successes > 0 else 0 + time_per_crawl = total_test_time / successes if successes > 0 else 0 + crawls_per_second = successes / total_test_time if total_test_time > 0 else 0 + crawls_per_minute = crawls_per_second * 60 + crawls_per_hour = crawls_per_minute * 60 - # Run performance scaling test - await test_performance_scaling() - # Run cleanup test - # await cleanup_browsers() + # Print simplified performance summary + from rich.console import Console + from rich.table import Table - # Run edge cases test - # await test_edge_cases() + console = Console() - print(f"\n{SUCCESS}All tests completed!{RESET}") + # Create a simple summary table + table = Table(title="CRAWL4AI PERFORMANCE SUMMARY") - except Exception as e: - print(f"\n{ERROR}Test failed with error: {str(e)}{RESET}") - import traceback + table.add_column("Metric", style="cyan") + table.add_column("Value", style="green") - traceback.print_exc() - finally: - # Clean up: kill any remaining builtin browsers - await cleanup_browsers() - print(f"{SUCCESS}Test cleanup complete{RESET}") + table.add_row("Total Crawls Completed", f"{successes}") + table.add_row("Total Time", f"{total_test_time:.2f} seconds") + table.add_row("Time Per Crawl", f"{time_per_crawl:.2f} seconds") + table.add_row("Crawling Speed", f"{crawls_per_second:.2f} crawls/second") + table.add_row("Projected Rate (1 minute)", f"{crawls_per_minute:.0f} crawls") + table.add_row("Projected Rate (1 hour)", f"{crawls_per_hour:.0f} crawls") + table.add_row("Peak Memory Usage", f"{browsers_manager.peak_memory:.2f} MB") + table.add_row("Memory Per Crawl", f"{memory_per_page:.2f} MB") + # Display the table + console.print(table) if __name__ == "__main__": - asyncio.run(main()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/browser/test_builtin_strategy.py 
b/tests/browser/test_builtin_strategy.py index 7c435b3de..0f8643013 100644 --- a/tests/browser/test_builtin_strategy.py +++ b/tests/browser/test_builtin_strategy.py @@ -4,13 +4,8 @@ and serve as functional tests. """ -import asyncio -import os import sys - -# Add the project root to Python path if running directly -if __name__ == "__main__": - sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) +import pytest from crawl4ai.browser import BrowserManager from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig @@ -19,6 +14,7 @@ # Create a logger for clear terminal output logger = AsyncLogger(verbose=True, log_file=None) +@pytest.mark.asyncio async def test_builtin_browser(): """Test using a builtin browser that persists between sessions.""" logger.info("Testing builtin browser", tag="TEST") @@ -68,10 +64,11 @@ async def test_builtin_browser(): logger.error(f"Test failed: {str(e)}", tag="TEST") try: await manager.close() - except: + except Exception: pass return False +@pytest.mark.asyncio async def test_builtin_browser_status(): """Test getting status of the builtin browser.""" logger.info("Testing builtin browser status", tag="TEST") @@ -135,26 +132,11 @@ async def test_builtin_browser_status(): # Try to kill the builtin browser to clean up strategy2 = BuiltinBrowserStrategy(browser_config, logger) await strategy2.kill_builtin_browser() - except: + except Exception: pass return False -async def run_tests(): - """Run all tests sequentially.""" - results = [] - - results.append(await test_builtin_browser()) - results.append(await test_builtin_browser_status()) - - # Print summary - total = len(results) - passed = sum(results) - logger.info(f"Tests complete: {passed}/{total} passed", tag="SUMMARY") - - if passed == total: - logger.success("All tests passed!", tag="SUMMARY") - else: - logger.error(f"{total - passed} tests failed", tag="SUMMARY") - if __name__ == "__main__": - asyncio.run(run_tests()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/browser/test_cdp_strategy.py b/tests/browser/test_cdp_strategy.py index abadf42a2..54bb8bfab 100644 --- a/tests/browser/test_cdp_strategy.py +++ b/tests/browser/test_cdp_strategy.py @@ -4,13 +4,8 @@ and serve as functional tests. 
""" -import asyncio -import os import sys - -# Add the project root to Python path if running directly -if __name__ == "__main__": - sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) +import pytest from crawl4ai.browser import BrowserManager from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig @@ -19,6 +14,7 @@ # Create a logger for clear terminal output logger = AsyncLogger(verbose=True, log_file=None) +@pytest.mark.asyncio async def test_cdp_launch_connect(): """Test launching a browser and connecting via CDP.""" logger.info("Testing launch and connect via CDP", tag="TEST") @@ -56,10 +52,11 @@ async def test_cdp_launch_connect(): logger.error(f"Test failed: {str(e)}", tag="TEST") try: await manager.close() - except: + except Exception: pass return False +@pytest.mark.asyncio async def test_cdp_with_user_data_dir(): """Test CDP browser with a user data directory.""" logger.info("Testing CDP browser with user data directory", tag="TEST") @@ -124,25 +121,26 @@ async def test_cdp_with_user_data_dir(): # Remove temporary directory import shutil shutil.rmtree(user_data_dir, ignore_errors=True) - logger.info(f"Removed temporary user data directory", tag="TEST") - + logger.info("Removed temporary user data directory", tag="TEST") + return has_test_cookie and has_test_cookie2 except Exception as e: logger.error(f"Test failed: {str(e)}", tag="TEST") try: await manager.close() - except: + except Exception: pass # Clean up temporary directory try: import shutil shutil.rmtree(user_data_dir, ignore_errors=True) - except: + except Exception: pass return False +@pytest.mark.asyncio async def test_cdp_session_management(): """Test session management with CDP browser.""" logger.info("Testing session management with CDP browser", tag="TEST") @@ -186,8 +184,8 @@ async def test_cdp_session_management(): # Kill first session await manager.kill_session(session1_id) - logger.info(f"Killed session 1", tag="TEST") - + logger.info("Killed session 1", tag="TEST") + # Verify second session still works data2 = await page2.evaluate("localStorage.getItem('session2_data')") logger.info(f"Session 2 still functional after killing session 1, data: {data2}", tag="TEST") @@ -201,27 +199,11 @@ async def test_cdp_session_management(): logger.error(f"Test failed: {str(e)}", tag="TEST") try: await manager.close() - except: + except Exception: pass return False -async def run_tests(): - """Run all tests sequentially.""" - results = [] - - # results.append(await test_cdp_launch_connect()) - # results.append(await test_cdp_with_user_data_dir()) - results.append(await test_cdp_session_management()) - - # Print summary - total = len(results) - passed = sum(results) - logger.info(f"Tests complete: {passed}/{total} passed", tag="SUMMARY") - - if passed == total: - logger.success("All tests passed!", tag="SUMMARY") - else: - logger.error(f"{total - passed} tests failed", tag="SUMMARY") - if __name__ == "__main__": - asyncio.run(run_tests()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/browser/test_combined.py b/tests/browser/test_combined.py deleted file mode 100644 index b5bce3cda..000000000 --- a/tests/browser/test_combined.py +++ /dev/null @@ -1,77 +0,0 @@ -"""Combined test runner for all browser module tests. - -This script runs all the browser module tests in sequence and -provides a comprehensive summary. 
-""" - -import asyncio -import os -import sys -import time - -# Add the project root to Python path if running directly -if __name__ == "__main__": - sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) - -from crawl4ai.async_logger import AsyncLogger - -# Create a logger for clear terminal output -logger = AsyncLogger(verbose=True, log_file=None) - -async def run_test_module(module_name, header): - """Run all tests in a module and return results.""" - logger.info(f"\n{'-'*30}", tag="TEST") - logger.info(f"RUNNING: {header}", tag="TEST") - logger.info(f"{'-'*30}", tag="TEST") - - # Import the module dynamically - module = __import__(f"tests.browser.{module_name}", fromlist=["run_tests"]) - - # Track time for performance measurement - start_time = time.time() - - # Run the tests - await module.run_tests() - - # Calculate time taken - time_taken = time.time() - start_time - logger.info(f"Time taken: {time_taken:.2f} seconds", tag="TIMING") - - return time_taken - -async def main(): - """Run all test modules.""" - logger.info("STARTING COMPREHENSIVE BROWSER MODULE TESTS", tag="MAIN") - - # List of test modules to run - test_modules = [ - ("test_browser_manager", "Browser Manager Tests"), - ("test_playwright_strategy", "Playwright Strategy Tests"), - ("test_cdp_strategy", "CDP Strategy Tests"), - ("test_builtin_strategy", "Builtin Browser Strategy Tests"), - ("test_profiles", "Profile Management Tests") - ] - - # Run each test module - timings = {} - for module_name, header in test_modules: - try: - time_taken = await run_test_module(module_name, header) - timings[module_name] = time_taken - except Exception as e: - logger.error(f"Error running {module_name}: {str(e)}", tag="ERROR") - - # Print summary - logger.info("\n\nTEST SUMMARY:", tag="SUMMARY") - logger.info(f"{'-'*50}", tag="SUMMARY") - for module_name, header in test_modules: - if module_name in timings: - logger.info(f"{header}: {timings[module_name]:.2f} seconds", tag="SUMMARY") - else: - logger.error(f"{header}: FAILED TO RUN", tag="SUMMARY") - logger.info(f"{'-'*50}", tag="SUMMARY") - total_time = sum(timings.values()) - logger.info(f"Total time: {total_time:.2f} seconds", tag="SUMMARY") - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/tests/browser/test_launch_standalone.py b/tests/browser/test_launch_standalone.py index d60b12f3f..28f2ade11 100644 --- a/tests/browser/test_launch_standalone.py +++ b/tests/browser/test_launch_standalone.py @@ -1,17 +1,21 @@ +import os +import sys +import pytest from crawl4ai.browser_profiler import BrowserProfiler -import asyncio +@pytest.mark.asyncio +@pytest.mark.skip(reason="Requires user interaction to stop the browser, more work needed") +async def test_standalone_browser(): + profiler = BrowserProfiler() + cdp_url = await profiler.launch_standalone_browser( + browser_type="chromium", + user_data_dir=os.path.expanduser("~/.crawl4ai/browser_profile/test-browser-data"), + debugging_port=9222, + headless=False + ) + assert cdp_url is not None, "Failed to launch standalone browser" if __name__ == "__main__": - # Test launching a standalone browser - async def test_standalone_browser(): - profiler = BrowserProfiler() - cdp_url = await profiler.launch_standalone_browser( - browser_type="chromium", - user_data_dir="~/.crawl4ai/browser_profile/test-browser-data", - debugging_port=9222, - headless=False - ) - print(f"CDP URL: {cdp_url}") + import subprocess - asyncio.run(test_standalone_browser()) \ No newline at end of file + 
sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) \ No newline at end of file diff --git a/tests/browser/test_parallel_crawling.py b/tests/browser/test_parallel_crawling.py index 9e72f06e3..92657ea16 100644 --- a/tests/browser/test_parallel_crawling.py +++ b/tests/browser/test_parallel_crawling.py @@ -6,14 +6,8 @@ """ import asyncio -import os -import sys import time -from typing import List - -# Add the project root to Python path if running directly -if __name__ == "__main__": - sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) +import pytest from crawl4ai.browser import BrowserManager from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig @@ -22,6 +16,7 @@ # Create a logger for clear terminal output logger = AsyncLogger(verbose=True, log_file=None) +@pytest.mark.asyncio async def test_get_pages_basic(): """Test basic functionality of get_pages method.""" logger.info("Testing basic get_pages functionality", tag="TEST") @@ -53,10 +48,11 @@ async def test_get_pages_basic(): logger.error(f"Test failed: {str(e)}", tag="TEST") try: await manager.close() - except: + except Exception: pass return False +@pytest.mark.asyncio async def test_parallel_approaches_comparison(): """Compare two parallel crawling approaches: 1. Create a page for each URL on-demand (get_page + gather) @@ -148,10 +144,11 @@ async def fetch_title_approach2(page_ctx, url): logger.error(f"Test failed: {str(e)}", tag="TEST") try: await manager.close() - except: + except Exception: pass return False +@pytest.mark.asyncio async def test_multi_browser_scaling(num_browsers=3, pages_per_browser=5): """Test performance with multiple browsers and pages per browser. Compares two approaches: @@ -166,8 +163,7 @@ async def test_multi_browser_scaling(num_browsers=3, pages_per_browser=5): # Create browser managers managers = [] - base_port = 9222 - + try: # Start all browsers in parallel start_tasks = [] @@ -268,7 +264,7 @@ async def fetch_title_approach2(page_ctx, url): for manager in managers: try: await manager.close() - except: + except Exception: pass return False @@ -469,7 +465,7 @@ async def fetch_title(page_ctx, url): for manager in managers: try: await manager.close() - except: + except Exception: pass # Print summary of all configurations @@ -880,7 +876,7 @@ async def run_tests(): if configs: # Show the optimal configuration optimal = configs["optimal"] - print(f"\n🎯 Recommended configuration for production use:") + print("\n🎯 Recommended configuration for production use:") print(f" {optimal['browser_count']} browsers with distribution {optimal['distribution']}") print(f" Estimated performance: {optimal['pages_per_second']:.1f} pages/second") results.append(True) diff --git a/tests/browser/test_playwright_strategy.py b/tests/browser/test_playwright_strategy.py index 2344c9bae..c0f209b33 100644 --- a/tests/browser/test_playwright_strategy.py +++ b/tests/browser/test_playwright_strategy.py @@ -4,13 +4,8 @@ and serve as functional tests. 
""" -import asyncio -import os import sys - -# Add the project root to Python path if running directly -if __name__ == "__main__": - sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) +import pytest from crawl4ai.browser import BrowserManager from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig @@ -19,6 +14,7 @@ # Create a logger for clear terminal output logger = AsyncLogger(verbose=True, log_file=None) +@pytest.mark.asyncio async def test_playwright_basic(): """Test basic Playwright browser functionality.""" logger.info("Testing standard Playwright browser", tag="TEST") @@ -63,10 +59,11 @@ async def test_playwright_basic(): # Ensure cleanup try: await manager.close() - except: + except Exception: pass return False +@pytest.mark.asyncio async def test_playwright_text_mode(): """Test Playwright browser in text-only mode.""" logger.info("Testing Playwright text mode", tag="TEST") @@ -106,7 +103,7 @@ async def test_playwright_text_mode(): await page.goto("https://picsum.photos/", wait_until="domcontentloaded") request = await request_info.value has_images = True - except: + except Exception: # Timeout without image requests means text mode is working has_images = False @@ -122,10 +119,11 @@ async def test_playwright_text_mode(): # Ensure cleanup try: await manager.close() - except: + except Exception: pass return False +@pytest.mark.asyncio async def test_playwright_context_reuse(): """Test context caching and reuse with identical configurations.""" logger.info("Testing context reuse with identical configurations", tag="TEST") @@ -178,10 +176,11 @@ async def test_playwright_context_reuse(): # Ensure cleanup try: await manager.close() - except: + except Exception: pass return False +@pytest.mark.asyncio async def test_playwright_session_management(): """Test session management with Playwright browser.""" logger.info("Testing session management with Playwright browser", tag="TEST") @@ -225,8 +224,8 @@ async def test_playwright_session_management(): # Kill first session await manager.kill_session(session1_id) - logger.info(f"Killed session 1", tag="TEST") - + logger.info("Killed session 1", tag="TEST") + # Verify second session still works data2 = await page2.evaluate("localStorage.getItem('playwright_session2_data')") logger.info(f"Session 2 still functional after killing session 1, data: {data2}", tag="TEST") @@ -240,28 +239,12 @@ async def test_playwright_session_management(): logger.error(f"Test failed: {str(e)}", tag="TEST") try: await manager.close() - except: + except Exception: pass return False -async def run_tests(): - """Run all tests sequentially.""" - results = [] - - results.append(await test_playwright_basic()) - results.append(await test_playwright_text_mode()) - results.append(await test_playwright_context_reuse()) - results.append(await test_playwright_session_management()) - - # Print summary - total = len(results) - passed = sum(results) - logger.info(f"Tests complete: {passed}/{total} passed", tag="SUMMARY") - - if passed == total: - logger.success("All tests passed!", tag="SUMMARY") - else: - logger.error(f"{total - passed} tests failed", tag="SUMMARY") if __name__ == "__main__": - asyncio.run(run_tests()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/browser/test_profiles.py b/tests/browser/test_profiles.py index 8325b561a..0a8c32c75 100644 --- a/tests/browser/test_profiles.py +++ b/tests/browser/test_profiles.py @@ -4,15 +4,11 @@ and serve as functional tests. 
""" -import asyncio import os import sys import uuid import shutil - -# Add the project root to Python path if running directly -if __name__ == "__main__": - sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) +import pytest from crawl4ai.browser import BrowserManager, BrowserProfileManager from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig @@ -21,6 +17,7 @@ # Create a logger for clear terminal output logger = AsyncLogger(verbose=True, log_file=None) +@pytest.mark.asyncio async def test_profile_creation(): """Test creating and managing browser profiles.""" logger.info("Testing profile creation and management", tag="TEST") @@ -75,10 +72,11 @@ async def test_profile_creation(): try: if os.path.exists(profile_path): shutil.rmtree(profile_path, ignore_errors=True) - except: + except Exception: pass return False +@pytest.mark.asyncio async def test_profile_with_browser(): """Test using a profile with a browser.""" logger.info("Testing using a profile with a browser", tag="TEST") @@ -151,26 +149,11 @@ async def test_profile_with_browser(): try: if profile_path and os.path.exists(profile_path): shutil.rmtree(profile_path, ignore_errors=True) - except: + except Exception: pass return False -async def run_tests(): - """Run all tests sequentially.""" - results = [] - - results.append(await test_profile_creation()) - results.append(await test_profile_with_browser()) - - # Print summary - total = len(results) - passed = sum(results) - logger.info(f"Tests complete: {passed}/{total} passed", tag="SUMMARY") - - if passed == total: - logger.success("All tests passed!", tag="SUMMARY") - else: - logger.error(f"{total - passed} tests failed", tag="SUMMARY") - if __name__ == "__main__": - asyncio.run(run_tests()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) \ No newline at end of file diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index b7416dc29..f9ddea738 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -1,17 +1,22 @@ -import pytest -from click.testing import CliRunner -from pathlib import Path import json -import yaml -from crawl4ai.cli import cli, load_config_file, parse_key_values -import tempfile import os +import sys +import tempfile +from pathlib import Path + import click +import pytest +import yaml +from click.testing import CliRunner, Result + +from crawl4ai.cli import cli, load_config_file, parse_key_values + @pytest.fixture -def runner(): +def runner() -> CliRunner: return CliRunner() + @pytest.fixture def temp_config_dir(): with tempfile.TemporaryDirectory() as tmpdir: @@ -21,8 +26,9 @@ def temp_config_dir(): if old_home: os.environ['HOME'] = old_home + @pytest.fixture -def sample_configs(temp_config_dir): +def sample_configs(temp_config_dir: Path) -> dict[str, str]: configs = { 'browser.yml': { 'headless': True, @@ -59,20 +65,21 @@ def sample_configs(temp_config_dir): return {name: str(temp_config_dir / name) for name in configs} class TestCLIBasics: - def test_help(self, runner): - result = runner.invoke(cli, ['--help']) + def test_help(self, runner: CliRunner): + result: Result = runner.invoke(cli, ['--help']) assert result.exit_code == 0 assert 'Crawl4AI CLI' in result.output - def test_examples(self, runner): - result = runner.invoke(cli, ['--example']) + def test_examples(self, runner: CliRunner): + result: Result = runner.invoke(cli, ['examples']) assert result.exit_code == 0 assert 'Examples' in result.output def test_missing_url(self, runner): - result = 
runner.invoke(cli) + result: Result = runner.invoke(cli, ['crawl']) assert result.exit_code != 0 - assert 'URL argument is required' in result.output + assert "Error: Missing argument 'URL'" in result.output + class TestConfigParsing: def test_parse_key_values_basic(self): @@ -99,35 +106,38 @@ def test_load_nonexistent_config(self): load_config_file('nonexistent.yml') class TestLLMConfig: - def test_llm_config_creation(self, temp_config_dir, runner): + def test_llm_config_creation(self, temp_config_dir: Path, runner: CliRunner): def input_simulation(inputs): return runner.invoke(cli, ['https://example.com', '-q', 'test question'], input='\n'.join(inputs)) class TestCrawlingFeatures: - def test_basic_crawl(self, runner): - result = runner.invoke(cli, ['https://example.com']) + def test_basic_crawl(self, runner: CliRunner): + result: Result = runner.invoke(cli, ['crawl', 'https://example.com']) assert result.exit_code == 0 class TestErrorHandling: - def test_invalid_config_file(self, runner): - result = runner.invoke(cli, [ + def test_invalid_config_file(self, runner: CliRunner): + result: Result = runner.invoke(cli, [ 'https://example.com', '--browser-config', 'nonexistent.yml' ]) assert result.exit_code != 0 - def test_invalid_schema(self, runner, temp_config_dir): + def test_invalid_schema(self, runner: CliRunner, temp_config_dir: Path): invalid_schema = temp_config_dir / 'invalid_schema.json' with open(invalid_schema, 'w') as f: f.write('invalid json') - - result = runner.invoke(cli, [ + + result: Result = runner.invoke(cli, [ 'https://example.com', '--schema', str(invalid_schema) ]) assert result.exit_code != 0 -if __name__ == '__main__': - pytest.main(['-v', '-s', '--tb=native', __file__]) \ No newline at end of file + +if __name__ == "__main__": + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/docker/__init__.py b/tests/docker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/docker/common.py b/tests/docker/common.py new file mode 100644 index 000000000..59bd47319 --- /dev/null +++ b/tests/docker/common.py @@ -0,0 +1,51 @@ +from typing import List + +import pytest +from _pytest.mark import ParameterSet +from httpx import ASGITransport, AsyncClient + +from crawl4ai.docker_client import Crawl4aiDockerClient +from deploy.docker.server import app + +TEST_URLS = [ + "example.com", + "https://www.python.org", + "https://news.ycombinator.com/news", + "https://github.com/trending", +] +BASE_URL = "http://localhost:8000" + + +def async_client() -> AsyncClient: + """Create an async client for the server. + + This can be used to test the API server without running the server.""" + return AsyncClient(transport=ASGITransport(app=app), base_url=BASE_URL) + + +def docker_client() -> Crawl4aiDockerClient: + """Crawl4aiDockerClient docker client via local transport. 
+ + This can be used to test the API server without running the server.""" + return Crawl4aiDockerClient(transport=ASGITransport(app=app), verbose=True) + + +def markdown_params(urls: List[str] = TEST_URLS) -> List[ParameterSet]: + """Parameters for markdown endpoint tests with different filters""" + tests = [] + for url in urls: + for filter_type in ["raw", "fit", "bm25", "llm"]: + for cache in ["0", "1"]: + params: dict[str, str] = {"f": filter_type, "c": cache} + if filter_type in ["bm25", "llm"]: + params["q"] = "extract main content" + + tests.append( + pytest.param( + url, + params, + id=f"{url} {filter_type}" + (" cached" if cache == "1" else ""), + ) + ) + + return tests diff --git a/tests/docker/test_config_object.py b/tests/docker/test_config_object.py index 94a30f058..b723cf7d9 100644 --- a/tests/docker/test_config_object.py +++ b/tests/docker/test_config_object.py @@ -1,4 +1,5 @@ import json +import sys from crawl4ai import ( CrawlerRunConfig, DefaultMarkdownGenerator, @@ -8,9 +9,8 @@ CacheMode ) from crawl4ai.deep_crawling import BFSDeepCrawlStrategy -from crawl4ai.deep_crawling.filters import FastFilterChain -from crawl4ai.deep_crawling.filters import FastContentTypeFilter, FastDomainFilter -from crawl4ai.deep_crawling.scorers import FastKeywordRelevanceScorer +from crawl4ai.deep_crawling.filters import FilterChain, ContentTypeFilter, DomainFilter +from crawl4ai.deep_crawling.scorers import KeywordRelevanceScorer def create_test_config() -> CrawlerRunConfig: # Set up content filtering and markdown generation @@ -35,12 +35,12 @@ def create_test_config() -> CrawlerRunConfig: extraction_strategy = JsonCssExtractionStrategy(schema=extraction_schema) # Set up deep crawling - filter_chain = FastFilterChain([ - FastContentTypeFilter(["text/html"]), - FastDomainFilter(blocked_domains=["ads.*"]) + filter_chain = FilterChain([ + ContentTypeFilter(["text/html"]), + DomainFilter(blocked_domains=["ads.*"]), ]) - url_scorer = FastKeywordRelevanceScorer( + url_scorer = KeywordRelevanceScorer( keywords=["article", "blog"], weight=1.0 ) @@ -104,10 +104,12 @@ def test_config_serialization_cycle(): # Verify deep crawl strategy configuration assert deserialized_config.deep_crawl_strategy.max_depth == 3 - assert isinstance(deserialized_config.deep_crawl_strategy.filter_chain, FastFilterChain) - assert isinstance(deserialized_config.deep_crawl_strategy.url_scorer, FastKeywordRelevanceScorer) + assert isinstance(deserialized_config.deep_crawl_strategy.filter_chain, FilterChain) + assert isinstance(deserialized_config.deep_crawl_strategy.url_scorer, KeywordRelevanceScorer) print("Serialization cycle test passed successfully!") if __name__ == "__main__": - test_config_serialization_cycle() \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/docker/test_core.py b/tests/docker/test_core.py new file mode 100644 index 000000000..9954b55b8 --- /dev/null +++ b/tests/docker/test_core.py @@ -0,0 +1,322 @@ +import base64 +import io +import json +import os +import sys +import time +from typing import Any, Dict + +import pytest +import requests +from PIL import Image, ImageFile + +from crawl4ai import ( + CosineStrategy, + JsonCssExtractionStrategy, + LLMConfig, + LLMExtractionStrategy, +) +from crawl4ai.async_configs import BrowserConfig +from crawl4ai.async_webcrawler import CrawlerRunConfig + +from .common import async_client + + +class Crawl4AiTester: + def __init__(self): + self.client = async_client() + + async def 
submit_and_validate( + self, request_data: Dict[str, Any], timeout: int = 300 + ) -> Dict[str, Any]: + """Submit a crawl request and validate for the result. + + The response is validated to ensure that it is successful and contains at least one result. + + :param request_data: The request data to submit. + :type request_data: Dict[str, Any] + :param timeout: The maximum time to wait for the response, defaults to 300. + :type timeout: int, optional + :return: The response of the crawl decoded from JSON. + :rtype: Dict[str, Any] + """ + response = await self.client.post("/crawl", json=request_data) + return self.assert_valid_response(response.json()) + + async def check_health(self) -> None: + """Check the health of the service. + + Check the health of the service and wait for it to be ready. + If the service is not ready after a few retries, the test will fail.""" + max_retries = 5 + for i in range(max_retries): + try: + health = await self.client.get("/health", timeout=10) + print("Health check:", health.json()) + return + except requests.exceptions.RequestException: + if i == max_retries - 1: + assert False, f"Failed to connect after {max_retries} attempts" + + print( + f"Waiting for service to start (attempt {i + 1}/{max_retries})..." + ) + time.sleep(5) + + pytest.fail(f"Failed to connect to service after {max_retries} retries") + + def assert_valid_response(self, result: Dict[str, Any]) -> Dict[str, Any]: + """Assert that the response is valid and returns the first result. + + :param result: The response from the API + :type result: Dict[str, Any] + :return: The first result + :rtype: dict[str, Any] + """ + assert result["success"] + assert result["results"] + assert len(result["results"]) > 0 + assert result["results"][0]["url"] + assert result["results"][0]["html"] + return result["results"][0] + + +@pytest.fixture +def tester() -> Crawl4AiTester: + return Crawl4AiTester() + + +@pytest.mark.asyncio +async def test_basic_crawl(tester: Crawl4AiTester): + request = {"urls": ["https://www.nbcnews.com/business"], "priority": 10} + + await tester.submit_and_validate(request) + + +@pytest.mark.asyncio +async def test_js_execution(tester: Crawl4AiTester): + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 8, + "js_code": [ + "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" + ], + "wait_for": "article.tease-card:nth-child(10)", + "crawler_params": {"headless": True}, + } + + await tester.submit_and_validate(request) + + +@pytest.mark.asyncio +async def test_css_selector(tester: Crawl4AiTester): + print("\n=== Testing CSS Selector ===") + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 7, + "css_selector": ".wide-tease-item__description", + "crawler_params": {"headless": True}, + "extra": {"word_count_threshold": 10}, + } + + await tester.submit_and_validate(request) + + +@pytest.mark.asyncio +async def test_structured_extraction(tester: Crawl4AiTester): + schema = { + "name": "Coinbase Crypto Prices", + "baseSelector": "table > tbody > tr", + "fields": [ + { + "name": "crypto", + "selector": "td:nth-child(1) h2", + "type": "text", + }, + { + "name": "symbol", + "selector": "td:nth-child(1) p", + "type": "text", + }, + { + "name": "price", + "selector": "td:nth-child(2)", + "type": "text", + }, + ], + } + crawler_config: CrawlerRunConfig = CrawlerRunConfig( + extraction_strategy=JsonCssExtractionStrategy(schema), + ) + 
request = { + "urls": ["https://www.coinbase.com/explore"], + "priority": 9, + "crawler_config": crawler_config.dump(), + } + + result = await tester.submit_and_validate(request) + + extracted = json.loads(result["extracted_content"]) + print(f"Extracted {len(extracted)} items") + print("Sample item:", json.dumps(extracted[0], indent=2)) + + assert len(extracted) > 0 + + +@pytest.mark.asyncio +async def test_llm_extraction(tester: Crawl4AiTester): + if not os.getenv("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + schema = { + "type": "object", + "properties": { + "model_name": { + "type": "string", + "description": "Name of the OpenAI model.", + }, + "input_fee": { + "type": "string", + "description": "Fee for input token for the OpenAI model.", + }, + "output_fee": { + "type": "string", + "description": "Fee for output token for the OpenAI model.", + }, + }, + "required": ["model_name", "input_fee", "output_fee"], + } + + crawler_config: CrawlerRunConfig = CrawlerRunConfig( + extraction_strategy=LLMExtractionStrategy( + schema=schema, + llm_config=LLMConfig( + provider="openai/gpt-4o-mini", + api_token=os.getenv("OPENAI_API_KEY"), + ), + extraction_type="schema", + instruction="From the crawled content, extract all mentioned model names along with their fees for input and output tokens.", + ), + word_count_threshold=1, + ) + request = { + "urls": ["https://openai.com/api/pricing"], + "priority": 8, + "crawler_config": crawler_config.dump(), + } + + result = await tester.submit_and_validate(request) + extracted = json.loads(result["extracted_content"]) + print(f"Extracted {len(extracted)} model pricing entries") + print("Sample entry:", json.dumps(extracted[0], indent=2)) + assert result["success"] + assert len(extracted) > 0, "No model pricing entries found" + + +@pytest.mark.asyncio +@pytest.mark.timeout(30) +async def test_llm_with_ollama(tester: Crawl4AiTester): + print("\n=== Testing LLM with Ollama ===") + schema = { + "type": "object", + "properties": { + "article_title": { + "type": "string", + "description": "The main title of the news article", + }, + "summary": { + "type": "string", + "description": "A brief summary of the article content", + }, + "main_topics": { + "type": "array", + "items": {"type": "string"}, + "description": "Main topics or themes discussed in the article", + }, + }, + } + + crawler_config: CrawlerRunConfig = CrawlerRunConfig( + extraction_strategy=LLMExtractionStrategy( + schema=schema, + llm_config=LLMConfig( + provider="ollama/llama2", + ), + extraction_type="schema", + instruction="Extract the main article information including title, summary, and main topics.", + ), + word_count_threshold=1, + ) + + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 8, + "crawler_config": crawler_config.dump(), + } + + result = await tester.submit_and_validate(request) + extracted = json.loads(result["extracted_content"]) + print("Extracted content:", json.dumps(extracted, indent=2)) + assert result["success"] + assert len(extracted) > 0, "No content extracted" + + +@pytest.mark.asyncio +async def test_cosine_extraction(tester: Crawl4AiTester): + print("\n=== Testing Cosine Extraction ===") + crawler_config: CrawlerRunConfig = CrawlerRunConfig( + extraction_strategy=CosineStrategy( + semantic_filter="business finance economy", + word_count_threshold=10, + max_dist=0.2, + top_k=3, + ), + word_count_threshold=1, + verbose=True, + ) + + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 8, + "crawler_config": 
crawler_config.dump(), + } + + result = await tester.submit_and_validate(request) + extracted = json.loads(result["extracted_content"]) + print(f"Extracted {len(extracted)} text clusters") + print("First cluster tags:", extracted[0]["tags"]) + assert result["success"] + assert len(extracted) > 0, "No clusters found" + + +@pytest.mark.asyncio +async def test_screenshot(tester: Crawl4AiTester): + crawler_config: CrawlerRunConfig = CrawlerRunConfig( + screenshot=True, + word_count_threshold=1, + verbose=True, + ) + + browser_config: BrowserConfig = BrowserConfig(headless=True) + + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 5, + "screenshot": True, + "crawler_config": crawler_config.dump(), + "browser_config": browser_config.dump(), + } + + result = await tester.submit_and_validate(request) + assert result.get("screenshot") + image: ImageFile.ImageFile = Image.open( + io.BytesIO(base64.b64decode(result["screenshot"])) + ) + + assert image.format == "BMP" + + +if __name__ == "__main__": + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/docker/test_crawl_task.py b/tests/docker/test_crawl_task.py new file mode 100644 index 000000000..596f1d2aa --- /dev/null +++ b/tests/docker/test_crawl_task.py @@ -0,0 +1,275 @@ +import asyncio +import json +import os +import sys +import time +from typing import Any, Dict + +import pytest +from httpx import AsyncClient, Response, codes + +from crawl4ai.async_configs import CrawlerRunConfig + +from .common import async_client + + +class NBCNewsAPITest: + def __init__(self): + self.client: AsyncClient = async_client() + + async def submit_crawl(self, request_data: Dict[str, Any]) -> str: + if "crawler_config" not in request_data: + request_data["crawler_config"] = CrawlerRunConfig(stream=True).dump() + + response: Response = await self.client.post("/crawl/stream", json=request_data) + assert response.status_code == codes.OK, f"Error: {response.text}" + result = response.json() + assert "task_id" in result + + return result["task_id"] + + async def get_task_status(self, task_id: str) -> Dict[str, Any]: + response: Response = await self.client.get(f"/task/{task_id}") + assert response.status_code == codes.OK + result = response.json() + assert "status" in result + return result + + async def wait_for_task( + self, task_id: str, timeout: int = 300, poll_interval: int = 2 + ) -> Dict[str, Any]: + start_time = time.time() + while True: + if time.time() - start_time > timeout: + raise TimeoutError( + f"Task {task_id} did not complete within {timeout} seconds" + ) + + status = await self.get_task_status(task_id) + if status["status"] in ["completed", "failed"]: + return status + + await asyncio.sleep(poll_interval) + + async def check_health(self) -> Dict[str, Any]: + response: Response = await self.client.get("/health") + assert response.status_code == codes.OK + return response.json() + + +@pytest.fixture +def api() -> NBCNewsAPITest: + return NBCNewsAPITest() + + +@pytest.mark.asyncio +@pytest.mark.skip("Crawl with task_id not implemented yet") +async def test_basic_crawl(api: NBCNewsAPITest): + print("\n=== Testing Basic Crawl ===") + request = {"urls": ["https://www.nbcnews.com/business"], "priority": 10} + task_id = await api.submit_crawl(request) + result = await api.wait_for_task(task_id) + print(f"Basic crawl result length: {len(result['result']['markdown'])}") + assert result["status"] == "completed" + assert "result" in result + assert result["result"]["success"] + + 
+@pytest.mark.asyncio +@pytest.mark.skip("Crawl with task_id not implemented yet") +async def test_js_execution(api: NBCNewsAPITest): + print("\n=== Testing JS Execution ===") + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 8, + "js_code": [ + "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" + ], + "wait_for": "article.tease-card:nth-child(10)", + "crawler_params": {"headless": True}, + } + task_id = await api.submit_crawl(request) + result = await api.wait_for_task(task_id) + print(f"JS execution result length: {len(result['result']['markdown'])}") + assert result["status"] == "completed" + assert result["result"]["success"] + + +@pytest.mark.asyncio +@pytest.mark.skip("Crawl with task_id not implemented yet") +async def test_css_selector(api: NBCNewsAPITest): + print("\n=== Testing CSS Selector ===") + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 7, + "css_selector": ".wide-tease-item__description", + } + task_id = await api.submit_crawl(request) + result = await api.wait_for_task(task_id) + print(f"CSS selector result length: {len(result['result']['markdown'])}") + assert result["status"] == "completed" + assert result["result"]["success"] + + +@pytest.mark.asyncio +@pytest.mark.skip("Crawl with task_id not implemented yet") +async def test_structured_extraction(api: NBCNewsAPITest): + print("\n=== Testing Structured Extraction ===") + schema = { + "name": "NBC News Articles", + "baseSelector": "article.tease-card", + "fields": [ + {"name": "title", "selector": "h2", "type": "text"}, + { + "name": "description", + "selector": ".tease-card__description", + "type": "text", + }, + { + "name": "link", + "selector": "a", + "type": "attribute", + "attribute": "href", + }, + ], + } + + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 9, + "extraction_config": {"type": "json_css", "params": {"schema": schema}}, + } + task_id = await api.submit_crawl(request) + result = await api.wait_for_task(task_id) + extracted = json.loads(result["result"]["extracted_content"]) + print(f"Extracted {len(extracted)} articles") + assert result["status"] == "completed" + assert result["result"]["success"] + assert len(extracted) > 0 + + +@pytest.mark.asyncio +@pytest.mark.skip("Crawl with task_id not implemented yet") +async def test_batch_crawl(api: NBCNewsAPITest): + print("\n=== Testing Batch Crawl ===") + request = { + "urls": [ + "https://www.nbcnews.com/business", + "https://www.nbcnews.com/business/consumer", + "https://www.nbcnews.com/business/economy", + ], + "priority": 6, + "crawler_params": {"headless": True}, + } + task_id = await api.submit_crawl(request) + result = await api.wait_for_task(task_id) + print(f"Batch crawl completed, got {len(result['results'])} results") + assert result["status"] == "completed" + assert "results" in result + assert len(result["results"]) == 3 + + +@pytest.mark.asyncio +@pytest.mark.skip("Crawl with task_id not implemented yet") +async def test_llm_extraction(api: NBCNewsAPITest): + print("\n=== Testing LLM Extraction with Ollama ===") + schema = { + "type": "object", + "properties": { + "article_title": { + "type": "string", + "description": "The main title of the news article", + }, + "summary": { + "type": "string", + "description": "A brief summary of the article content", + }, + "main_topics": { + "type": "array", + "items": {"type": "string"}, + "description": 
"Main topics or themes discussed in the article", + }, + }, + "required": ["article_title", "summary", "main_topics"], + } + + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 8, + "extraction_config": { + "type": "llm", + "params": { + "provider": "openai/gpt-4o-mini", + "api_key": os.getenv("OLLAMA_API_KEY"), + "schema": schema, + "extraction_type": "schema", + "instruction": """Extract the main article information including title, a brief summary, and main topics discussed. + Focus on the primary business news article on the page.""", + }, + }, + "crawler_params": {"headless": True, "word_count_threshold": 1}, + } + + task_id = await api.submit_crawl(request) + result = await api.wait_for_task(task_id) + + if result["status"] == "completed": + extracted = json.loads(result["result"]["extracted_content"]) + print("Extracted article analysis:") + print(json.dumps(extracted, indent=2)) + + assert result["status"] == "completed" + assert result["result"]["success"] + + +@pytest.mark.asyncio +@pytest.mark.skip("Crawl with task_id not implemented yet") +async def test_screenshot(api: NBCNewsAPITest): + print("\n=== Testing Screenshot ===") + request = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 5, + "screenshot": True, + "crawler_params": {"headless": True}, + } + task_id = await api.submit_crawl(request) + result = await api.wait_for_task(task_id) + print("Screenshot captured:", bool(result["result"]["screenshot"])) + assert result["status"] == "completed" + assert result["result"]["success"] + assert result["result"]["screenshot"] is not None + + +@pytest.mark.asyncio +@pytest.mark.skip("Crawl with task_id not implemented yet") +async def test_priority_handling(api: NBCNewsAPITest): + print("\n=== Testing Priority Handling ===") + # Submit low priority task first + low_priority = { + "urls": ["https://www.nbcnews.com/business"], + "priority": 1, + "crawler_params": {"headless": True}, + } + low_task_id = await api.submit_crawl(low_priority) + + # Submit high priority task + high_priority = { + "urls": ["https://www.nbcnews.com/business/consumer"], + "priority": 10, + "crawler_params": {"headless": True}, + } + high_task_id = await api.submit_crawl(high_priority) + + # Get both results + high_result = await api.wait_for_task(high_task_id) + low_result = await api.wait_for_task(low_task_id) + + print("Both tasks completed") + assert high_result["status"] == "completed" + assert low_result["status"] == "completed" + + +if __name__ == "__main__": + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/docker/test_docker.py b/tests/docker/test_docker.py index cf95671ea..f64355711 100644 --- a/tests/docker/test_docker.py +++ b/tests/docker/test_docker.py @@ -1,58 +1,25 @@ -import requests -import time -import httpx -import asyncio -from typing import Dict, Any +import sys + +from httpx import codes +import pytest + from crawl4ai import ( BrowserConfig, CrawlerRunConfig, DefaultMarkdownGenerator, PruningContentFilter, JsonCssExtractionStrategy, LLMContentFilter, CacheMode ) -from crawl4ai import LLMConfig -from crawl4ai.docker_client import Crawl4aiDockerClient - -class Crawl4AiTester: - def __init__(self, base_url: str = "http://localhost:11235"): - self.base_url = base_url - - def submit_and_wait( - self, request_data: Dict[str, Any], timeout: int = 300 - ) -> Dict[str, Any]: - # Submit crawl job - response = requests.post(f"{self.base_url}/crawl", json=request_data) - task_id = 
response.json()["task_id"] - print(f"Task ID: {task_id}") - - # Poll for result - start_time = time.time() - while True: - if time.time() - start_time > timeout: - raise TimeoutError( - f"Task {task_id} did not complete within {timeout} seconds" - ) +from crawl4ai.async_configs import LLMConfig - result = requests.get(f"{self.base_url}/task/{task_id}") - status = result.json() +from .common import async_client, docker_client - if status["status"] == "failed": - print("Task failed:", status.get("error")) - raise Exception(f"Task failed: {status.get('error')}") - if status["status"] == "completed": - return status +@pytest.fixture +def browser_config() -> BrowserConfig: + return BrowserConfig(headless=True, viewport_width=1200, viewport_height=800) - time.sleep(2) -async def test_direct_api(): - """Test direct API endpoints without using the client SDK""" - print("\n=== Testing Direct API Calls ===") - - # Test 1: Basic crawl with content filtering - browser_config = BrowserConfig( - headless=True, - viewport_width=1200, - viewport_height=800 - ) - +@pytest.mark.asyncio +async def test_direct_filtering(browser_config: BrowserConfig): + """Direct request with content filtering.""" crawler_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator( @@ -72,22 +39,25 @@ async def test_direct_api(): } # Make direct API call - async with httpx.AsyncClient() as client: + async with async_client() as client: response = await client.post( - "http://localhost:8000/crawl", + "/crawl", json=request_data, timeout=300 ) - assert response.status_code == 200 + assert response.status_code == codes.OK result = response.json() - print("Basic crawl result:", result["success"]) + assert result["success"] + - # Test 2: Structured extraction with JSON CSS +@pytest.mark.asyncio +async def test_direct_structured_extraction(browser_config: BrowserConfig): + """Direct request using structured extraction with JSON CSS.""" schema = { - "baseSelector": "article.post", + "baseSelector": "body > div", "fields": [ {"name": "title", "selector": "h1", "type": "text"}, - {"name": "content", "selector": ".content", "type": "html"} + {"name": "content", "selector": "p", "type": "html"} ] } @@ -96,30 +66,50 @@ async def test_direct_api(): extraction_strategy=JsonCssExtractionStrategy(schema=schema) ) - request_data["crawler_config"] = crawler_config.dump() + request_data = { + "urls": ["https://example.com"], + "browser_config": browser_config.dump(), + "crawler_config": crawler_config.dump() + } - async with httpx.AsyncClient() as client: + async with async_client() as client: response = await client.post( - "http://localhost:8000/crawl", + "/crawl", json=request_data ) - assert response.status_code == 200 + assert response.status_code == codes.OK result = response.json() - print("Structured extraction result:", result["success"]) + assert result["success"] + assert result["results"] + assert len(result["results"]) == 1 + assert "extracted_content" in result["results"][0] + assert ( + result["results"][0]["extracted_content"] + == """[ + { + "title": "Example Domain", + "content": "
<p>This domain is for use in illustrative examples in documents. You may use this\\n domain in literature without prior coordination or asking for permission.</p>
" + } +]""" + ) + + +@pytest.mark.asyncio +async def test_direct_schema(browser_config: BrowserConfig): + """Get the schema.""" + async with async_client() as client: + response = await client.get("/schema") + assert response.status_code == codes.OK + schemas = response.json() + assert schemas + assert len(schemas.keys()) == 2 + print("Retrieved schemas for:", list(schemas.keys())) - # Test 3: Get schema - # async with httpx.AsyncClient() as client: - # response = await client.get("http://localhost:8000/schema") - # assert response.status_code == 200 - # schemas = response.json() - # print("Retrieved schemas for:", list(schemas.keys())) -async def test_with_client(): +@pytest.mark.asyncio +async def test_with_client_basic(): """Test using the Crawl4AI Docker client SDK""" - print("\n=== Testing Client SDK ===") - - async with Crawl4aiDockerClient(verbose=True) as client: - # Test 1: Basic crawl + async with docker_client() as client: browser_config = BrowserConfig(headless=True) crawler_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, @@ -128,17 +118,23 @@ async def test_with_client(): threshold=0.48, threshold_type="fixed" ) - ) + ), + stream=False, ) + await client.authenticate("test@example.com") result = await client.crawl( urls=["https://example.com"], browser_config=browser_config, crawler_config=crawler_config ) - print("Client SDK basic crawl:", result.success) + assert result.success + - # Test 2: LLM extraction with streaming +@pytest.mark.asyncio +async def test_with_client_llm_streaming(): + async with docker_client() as client: + browser_config = BrowserConfig(headless=True) crawler_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, markdown_generator=DefaultMarkdownGenerator( @@ -150,26 +146,24 @@ async def test_with_client(): stream=True ) + await client.authenticate("test@example.com") async for result in await client.crawl( urls=["https://example.com"], browser_config=browser_config, crawler_config=crawler_config ): - print(f"Streaming result for: {result.url}") + assert result.success, f"Stream failed with: {result.error_message}" - # # Test 3: Get schema - # schemas = await client.get_schema() - # print("Retrieved client schemas for:", list(schemas.keys())) -async def main(): - """Run all tests""" - # Test direct API - print("Testing direct API calls...") - await test_direct_api() +@pytest.mark.asyncio +async def test_with_client_get_schema(): + async with docker_client() as client: + await client.authenticate("test@example.com") + schemas = await client.get_schema() + print("Retrieved client schemas for:", list(schemas.keys())) - # Test client SDK - print("\nTesting client SDK...") - await test_with_client() if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/docker/test_dockerclient.py b/tests/docker/test_dockerclient.py index cba6c4c9c..9b7eaa691 100644 --- a/tests/docker/test_dockerclient.py +++ b/tests/docker/test_dockerclient.py @@ -1,34 +1,50 @@ -import asyncio -from crawl4ai.docker_client import Crawl4aiDockerClient -from crawl4ai import ( - BrowserConfig, - CrawlerRunConfig -) - -async def main(): - async with Crawl4aiDockerClient(base_url="http://localhost:8000", verbose=True) as client: +import sys + +import pytest + +from crawl4ai import BrowserConfig, CrawlerRunConfig + +from .common import docker_client + + +@pytest.mark.asyncio +async def test_non_streaming(): + async with docker_client() as client: await 
client.authenticate("test@example.com") - + # Non-streaming crawl results = await client.crawl( ["https://example.com", "https://python.org"], browser_config=BrowserConfig(headless=True), crawler_config=CrawlerRunConfig() ) - print(f"Non-streaming results: {results}") - + assert results + for result in results: + assert result.success + print(f"Non-streamed result: {result}") + + +@pytest.mark.asyncio +async def test_streaming(): + async with docker_client() as client: # Streaming crawl crawler_config = CrawlerRunConfig(stream=True) + await client.authenticate("user@example.com") async for result in await client.crawl( ["https://example.com", "https://python.org"], browser_config=BrowserConfig(headless=True), crawler_config=crawler_config ): + assert result.success print(f"Streamed result: {result}") - + # Get schema schema = await client.get_schema() + assert schema print(f"Schema: {schema}") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/docker/test_serialization.py b/tests/docker/test_serialization.py index 6ce800059..93faa3e98 100644 --- a/tests/docker/test_serialization.py +++ b/tests/docker/test_serialization.py @@ -1,110 +1,24 @@ -import inspect -from typing import Any, Dict -from enum import Enum - -from crawl4ai import LLMConfig - -def to_serializable_dict(obj: Any) -> Dict: - """ - Recursively convert an object to a serializable dictionary using {type, params} structure - for complex objects. - """ - if obj is None: - return None - - # Handle basic types - if isinstance(obj, (str, int, float, bool)): - return obj - - # Handle Enum - if isinstance(obj, Enum): - return { - "type": obj.__class__.__name__, - "params": obj.value - } - - # Handle datetime objects - if hasattr(obj, 'isoformat'): - return obj.isoformat() - - # Handle lists, tuples, and sets - if isinstance(obj, (list, tuple, set)): - return [to_serializable_dict(item) for item in obj] - - # Handle dictionaries - preserve them as-is - if isinstance(obj, dict): - return { - "type": "dict", # Mark as plain dictionary - "value": {str(k): to_serializable_dict(v) for k, v in obj.items()} - } - - # Handle class instances - if hasattr(obj, '__class__'): - # Get constructor signature - sig = inspect.signature(obj.__class__.__init__) - params = sig.parameters - - # Get current values - current_values = {} - for name, param in params.items(): - if name == 'self': - continue - - value = getattr(obj, name, param.default) - - # Only include if different from default, considering empty values - if not (is_empty_value(value) and is_empty_value(param.default)): - if value != param.default: - current_values[name] = to_serializable_dict(value) - - return { - "type": obj.__class__.__name__, - "params": current_values - } - - return str(obj) - -def from_serializable_dict(data: Any) -> Any: - """ - Recursively convert a serializable dictionary back to an object instance. 
- """ - if data is None: - return None - - # Handle basic types - if isinstance(data, (str, int, float, bool)): - return data - - # Handle typed data - if isinstance(data, dict) and "type" in data: - # Handle plain dictionaries - if data["type"] == "dict": - return {k: from_serializable_dict(v) for k, v in data["value"].items()} - - # Import from crawl4ai for class instances - import crawl4ai - cls = getattr(crawl4ai, data["type"]) - - # Handle Enum - if issubclass(cls, Enum): - return cls(data["params"]) - - # Handle class instances - constructor_args = { - k: from_serializable_dict(v) for k, v in data["params"].items() - } - return cls(**constructor_args) - - # Handle lists - if isinstance(data, list): - return [from_serializable_dict(item) for item in data] - - # Handle raw dictionaries (legacy support) - if isinstance(data, dict): - return {k: from_serializable_dict(v) for k, v in data.items()} - - return data - +import sys +from typing import Any, List + +import pytest +from _pytest.mark import ParameterSet + +from crawl4ai import ( + BM25ContentFilter, + CacheMode, + CrawlerRunConfig, + DefaultMarkdownGenerator, + JsonCssExtractionStrategy, + LLMContentFilter, + LXMLWebScrapingStrategy, + PruningContentFilter, + RegexChunking, + WebScrapingStrategy, +) +from crawl4ai.async_configs import LLMConfig, from_serializable, to_serializable_dict + + def is_empty_value(value: Any) -> bool: """Check if a value is effectively empty/null.""" if value is None: @@ -113,90 +27,27 @@ def is_empty_value(value: Any) -> bool: return True return False -# if __name__ == "__main__": -# from crawl4ai import ( -# CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator, -# PruningContentFilter, BM25ContentFilter, LLMContentFilter, -# JsonCssExtractionStrategy, CosineStrategy, RegexChunking, -# WebScrapingStrategy, LXMLWebScrapingStrategy -# ) - -# # Test Case 1: BM25 content filtering through markdown generator -# config1 = CrawlerRunConfig( -# cache_mode=CacheMode.BYPASS, -# markdown_generator=DefaultMarkdownGenerator( -# content_filter=BM25ContentFilter( -# user_query="technology articles", -# bm25_threshold=1.2, -# language="english" -# ) -# ), -# chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]), -# excluded_tags=["nav", "footer", "aside"], -# remove_overlay_elements=True -# ) - -# # Serialize -# serialized = to_serializable_dict(config1) -# print("\nSerialized Config:") -# print(serialized) - -# # Example output structure would now look like: -# """ -# { -# "type": "CrawlerRunConfig", -# "params": { -# "cache_mode": { -# "type": "CacheMode", -# "params": "bypass" -# }, -# "markdown_generator": { -# "type": "DefaultMarkdownGenerator", -# "params": { -# "content_filter": { -# "type": "BM25ContentFilter", -# "params": { -# "user_query": "technology articles", -# "bm25_threshold": 1.2, -# "language": "english" -# } -# } -# } -# } -# } -# } -# """ - -# # Deserialize -# deserialized = from_serializable_dict(serialized) -# print("\nDeserialized Config:") -# print(to_serializable_dict(deserialized)) - -# # Verify they match -# assert to_serializable_dict(config1) == to_serializable_dict(deserialized) -# print("\nVerification passed: Configuration matches after serialization/deserialization!") - -if __name__ == "__main__": - from crawl4ai import ( - CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator, - PruningContentFilter, BM25ContentFilter, LLMContentFilter, - JsonCssExtractionStrategy, RegexChunking, - WebScrapingStrategy, LXMLWebScrapingStrategy - ) +def config_params() -> List[ParameterSet]: # 
Test Case 1: BM25 content filtering through markdown generator - config1 = CrawlerRunConfig( - cache_mode=CacheMode.BYPASS, - markdown_generator=DefaultMarkdownGenerator( - content_filter=BM25ContentFilter( - user_query="technology articles", - bm25_threshold=1.2, - language="english" - ) - ), - chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]), - excluded_tags=["nav", "footer", "aside"], - remove_overlay_elements=True + params: List[ParameterSet] = [] + params.append( + pytest.param( + CrawlerRunConfig( + cache_mode=CacheMode.BYPASS, + markdown_generator=DefaultMarkdownGenerator( + content_filter=BM25ContentFilter( + user_query="technology articles", + bm25_threshold=1.2, + language="english", + ) + ), + chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]), + excluded_tags=["nav", "footer", "aside"], + remove_overlay_elements=True, + ), + id="BM25 Content Filter", + ) ) # Test Case 2: LLM-based extraction with pruning filter @@ -204,52 +55,64 @@ def is_empty_value(value: Any) -> bool: "baseSelector": "article.post", "fields": [ {"name": "title", "selector": "h1", "type": "text"}, - {"name": "content", "selector": ".content", "type": "html"} - ] + {"name": "content", "selector": ".content", "type": "html"}, + ], } - config2 = CrawlerRunConfig( - extraction_strategy=JsonCssExtractionStrategy(schema=schema), - markdown_generator=DefaultMarkdownGenerator( - content_filter=PruningContentFilter( - threshold=0.48, - threshold_type="fixed", - min_word_threshold=0 + params.append( + pytest.param( + CrawlerRunConfig( + extraction_strategy=JsonCssExtractionStrategy(schema=schema), + markdown_generator=DefaultMarkdownGenerator( + content_filter=PruningContentFilter( + threshold=0.48, threshold_type="fixed", min_word_threshold=0 + ), + options={"ignore_links": True}, + ), + scraping_strategy=LXMLWebScrapingStrategy(), ), - options={"ignore_links": True} - ), - scraping_strategy=LXMLWebScrapingStrategy() + id="LLM Pruning Filter", + ) ) # Test Case 3:LLM content filter - config3 = CrawlerRunConfig( - markdown_generator=DefaultMarkdownGenerator( - content_filter=LLMContentFilter( - llm_config = LLMConfig(provider="openai/gpt-4"), - instruction="Extract key technical concepts", - chunk_token_threshold=2000, - overlap_rate=0.1 + params.append( + pytest.param( + CrawlerRunConfig( + markdown_generator=DefaultMarkdownGenerator( + content_filter=LLMContentFilter( + llm_config=LLMConfig(provider="openai/gpt-4"), + instruction="Extract key technical concepts", + chunk_token_threshold=2000, + overlap_rate=0.1, + ), + options={"ignore_images": True}, + ), + scraping_strategy=WebScrapingStrategy(), ), - options={"ignore_images": True} - ), - scraping_strategy=WebScrapingStrategy() + id="LLM Content Filter", + ) ) - # Test all configurations - test_configs = [config1, config2, config3] - - for i, config in enumerate(test_configs, 1): - print(f"\nTesting Configuration {i}:") - - # Serialize - serialized = to_serializable_dict(config) - print(f"\nSerialized Config {i}:") - print(serialized) - - # Deserialize - deserialized = from_serializable_dict(serialized) - print(f"\nDeserialized Config {i}:") - print(to_serializable_dict(deserialized)) # Convert back to dict for comparison - - # Verify they match - assert to_serializable_dict(config) == to_serializable_dict(deserialized) - print(f"\nVerification passed: Configuration {i} matches after serialization/deserialization!") \ No newline at end of file + return params + + +@pytest.mark.parametrize("config", config_params()) +def 
test_serialization(config: CrawlerRunConfig) -> None: + # Serialize + serialized = to_serializable_dict(config) + print("\nSerialized Config:") + print(serialized) + + # Deserialize + deserialized = from_serializable(serialized) + print("\nDeserialized Config:") + print(to_serializable_dict(deserialized)) # Convert back to dict for comparison + + # Verify they match + assert to_serializable_dict(config) == to_serializable_dict(deserialized) + + +if __name__ == "__main__": + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/docker/test_server.py b/tests/docker/test_server.py index 7bb0195bc..59dfd8130 100644 --- a/tests/docker/test_server.py +++ b/tests/docker/test_server.py @@ -1,146 +1,136 @@ import asyncio import json -from typing import Optional +import sys +from typing import Optional, Union from urllib.parse import quote -async def test_endpoint( - endpoint: str, - url: str, - params: Optional[dict] = None, - expected_status: int = 200 -) -> None: +import aiohttp +import pytest +from httpx import Response, codes + +from .common import TEST_URLS, async_client, markdown_params + +EndpointResponse = Optional[Union[dict, str]] + + +async def endpoint( + endpoint: str, url: str, params: Optional[dict] = None, expected_status: int = codes.OK +) -> EndpointResponse: """Test an endpoint and print results""" - import aiohttp - params = params or {} param_str = "&".join(f"{k}={v}" for k, v in params.items()) - full_url = f"http://localhost:8000/{endpoint}/{quote(url)}" + path = f"/{endpoint}/{quote(url)}" if param_str: - full_url += f"?{param_str}" - - print(f"\nTesting: {full_url}") - - try: - async with aiohttp.ClientSession() as session: - async with session.get(full_url) as response: - status = response.status - try: - data = await response.json() - except: - data = await response.text() - - print(f"Status: {status} (Expected: {expected_status})") - if isinstance(data, dict): - print(f"Response: {json.dumps(data, indent=2)}") - else: - print(f"Response: {data[:500]}...") # First 500 chars - assert status == expected_status - return data - except Exception as e: - print(f"Error: {str(e)}") - return None - -async def test_llm_task_completion(task_id: str) -> None: + path += f"?{param_str}" + + print(f"\nTesting: {path}") + + async with async_client() as session: + response: Response = await session.get(path) + content_type: str = response.headers.get( + aiohttp.hdrs.CONTENT_TYPE, "" + ).lower() + data: Union[dict, str] = ( + response.json() if content_type == "application/json" else response.text + ) + + print(f"Status: {response.status_code} (Expected: {expected_status})") + if isinstance(data, dict): + print(f"Response: {json.dumps(data, indent=2)}") + else: + print(f"Response: {data[:500]}...") # First 500 chars + assert response.status_code == expected_status + return data + + +async def llm_task_completion(task_id: str) -> Optional[dict]: """Poll task until completion""" for _ in range(10): # Try 10 times - result = await test_endpoint("llm", task_id) + result: EndpointResponse = await endpoint("llm", task_id) + assert result, "Failed to process endpoint request" + assert isinstance(result, dict), "Expected dict response" + if result and result.get("status") in ["completed", "failed"]: return result print("Task still processing, waiting 5 seconds...") await asyncio.sleep(5) print("Task timed out") + return None -async def run_tests(): - print("Starting API Tests...") - - # Test URLs - urls = [ - "example.com", - 
"https://www.python.org", - "https://news.ycombinator.com/news", - "https://github.com/trending" - ] - - print("\n=== Testing Markdown Endpoint ===") - for url in[] : #urls: - # Test different filter types - for filter_type in ["raw", "fit", "bm25", "llm"]: - params = {"f": filter_type} - if filter_type in ["bm25", "llm"]: - params["q"] = "extract main content" - - # Test with and without cache - for cache in ["0", "1"]: - params["c"] = cache - await test_endpoint("md", url, params) - await asyncio.sleep(1) # Be nice to the server - - print("\n=== Testing LLM Endpoint ===") - for url in []: # urls: - # Test basic extraction - result = await test_endpoint( - "llm", - url, - {"q": "Extract title and main content"} - ) - if result and "task_id" in result: - print("\nChecking task completion...") - await test_llm_task_completion(result["task_id"]) - - # Test with schema - schema = { - "type": "object", - "properties": { - "title": {"type": "string"}, - "content": {"type": "string"}, - "links": {"type": "array", "items": {"type": "string"}} - } - } - result = await test_endpoint( - "llm", - url, - { - "q": "Extract content with links", - "s": json.dumps(schema), - "c": "1" # Test with cache - } - ) - if result and "task_id" in result: - print("\nChecking schema task completion...") - await test_llm_task_completion(result["task_id"]) - - await asyncio.sleep(2) # Be nice to the server - - print("\n=== Testing Error Cases ===") - # Test invalid URL - await test_endpoint( - "md", - "not_a_real_url", - expected_status=500 - ) - - # Test invalid filter type - await test_endpoint( - "md", - "example.com", - {"f": "invalid"}, - expected_status=422 - ) - - # Test LLM without query - await test_endpoint( - "llm", - "example.com" + +@pytest.mark.asyncio +@pytest.mark.timeout(60) # LLM tasks can take a while. 
+@pytest.mark.parametrize("url,params", markdown_params()) +async def test_markdown_endpoint(url: str, params: dict[str, str]): + response: EndpointResponse = await endpoint("md", url, params) + assert response, "Failed to process endpoint request" + assert isinstance(response, str), "Expected str response" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("url", TEST_URLS) +@pytest.mark.skip("LLM endpoint doesn't task based requests yet") +async def test_llm_endpoint_no_schema(url: str): + result: EndpointResponse = await endpoint( + "llm", url, {"q": "Extract title and main content"} ) - - # Test invalid task ID - await test_endpoint( - "llm", - "llm_invalid_task", - expected_status=404 + assert result, "Failed to process endpoint request" + assert isinstance(result, dict), "Expected dict response" + assert "task_id" in result + + print("\nChecking task completion...") + await llm_task_completion(result["task_id"]) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("url", TEST_URLS) +@pytest.mark.skip("LLM endpoint doesn't task based or schema requests yet") +async def test_llm_endpoint_schema(url: str): + schema = { + "type": "object", + "properties": { + "title": {"type": "string"}, + "content": {"type": "string"}, + "links": {"type": "array", "items": {"type": "string"}}, + }, + } + result: EndpointResponse = await endpoint( + "llm", + url, + { + "q": "Extract content with links", + "s": json.dumps(schema), + "c": "1", # Test with cache + }, ) - - print("\nAll tests completed!") + assert result, "Failed to process endpoint request" + assert isinstance(result, dict), "Expected dict response" + assert "task_id" in result + print("\nChecking schema task completion...") + await llm_task_completion(result["task_id"]) + + +@pytest.mark.asyncio +async def test_invalid_url(): + await endpoint("md", "not_a_real_url", expected_status=codes.INTERNAL_SERVER_ERROR) + + +@pytest.mark.asyncio +async def test_invalid_filter(): + await endpoint("md", "example.com", {"f": "invalid"}, expected_status=codes.UNPROCESSABLE_ENTITY) + + +@pytest.mark.asyncio +async def test_llm_without_query(): + await endpoint("llm", "example.com", expected_status=codes.BAD_REQUEST) + + +@pytest.mark.asyncio +async def test_invalid_task(): + await endpoint("llm", "llm_invalid_task", expected_status=codes.BAD_REQUEST) + if __name__ == "__main__": - asyncio.run(run_tests()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/docker/test_server_token.py b/tests/docker/test_server_token.py index 220b6ca2c..0892866ac 100644 --- a/tests/docker/test_server_token.py +++ b/tests/docker/test_server_token.py @@ -1,212 +1,208 @@ -import asyncio import json -from typing import Optional +import sys +from typing import AsyncGenerator, Optional, Union from urllib.parse import quote -async def get_token(session, email: str = "test@example.com") -> str: +import aiohttp +import pytest +import pytest_asyncio +from httpx import AsyncClient, Response, codes + +from tests.docker.common import TEST_URLS, async_client, markdown_params + + +async def get_token(client: AsyncClient, email: str = "test@example.com") -> str: """Fetch a JWT token from the /token endpoint.""" - url = "http://localhost:8000/token" - payload = {"email": email} - print(f"\nFetching token from {url} with email: {email}") - try: - async with session.post(url, json=payload) as response: - status = response.status - data = await response.json() - print(f"Token Response Status: {status}") - 
print(f"Token Response: {json.dumps(data, indent=2)}") - if status == 200: - return data["access_token"] - else: - raise Exception(f"Failed to get token: {data.get('detail', 'Unknown error')}") - except Exception as e: - print(f"Error fetching token: {str(e)}") - raise - -async def test_endpoint( - session, - endpoint: str, - url: str, - token: str, - params: Optional[dict] = None, - expected_status: int = 200 -) -> Optional[dict]: - """Test an endpoint with token and print results.""" - params = params or {} - param_str = "&".join(f"{k}={v}" for k, v in params.items()) - full_url = f"http://localhost:8000/{endpoint}/{quote(url)}" - if param_str: - full_url += f"?{param_str}" - - headers = {"Authorization": f"Bearer {token}"} - print(f"\nTesting: {full_url}") - - try: - async with session.get(full_url, headers=headers) as response: - status = response.status - try: - data = await response.json() - except: - data = await response.text() - - print(f"Status: {status} (Expected: {expected_status})") - if isinstance(data, dict): - print(f"Response: {json.dumps(data, indent=2)}") - else: - print(f"Response: {data[:500]}...") # First 500 chars - assert status == expected_status, f"Expected {expected_status}, got {status}" - return data - except Exception as e: - print(f"Error: {str(e)}") - return None - - -async def test_stream_crawl(session, token: str): - """Test the /crawl/stream endpoint with multiple URLs.""" - url = "http://localhost:8000/crawl/stream" - payload = { - "urls": [ - "https://example.com", - "https://example.com/page1", # Replicated example.com with variation - "https://example.com/page2", # Replicated example.com with variation - "https://example.com/page3", # Replicated example.com with variation - # "https://www.python.org", - # "https://news.ycombinator.com/news" - ], - "browser_config": {"headless": True, "viewport": {"width": 1200}}, - "crawler_config": {"stream": True, "cache_mode": "bypass"} - } - headers = {"Authorization": f"Bearer {token}"} - print(f"\nTesting Streaming Crawl: {url}") - print(f"Payload: {json.dumps(payload, indent=2)}") - - try: - async with session.post(url, json=payload, headers=headers) as response: - status = response.status - print(f"Status: {status} (Expected: 200)") - assert status == 200, f"Expected 200, got {status}" - + path: str = "/token" + payload: dict[str, str] = {"email": email} + print(f"\nFetching token from {path} with email: {email}") + response: Response = await client.post(path, json=payload) + data = response.json() + print(f"Token Response Status: {response.status_code}") + print(f"Token Response: {json.dumps(data, indent=2)}") + assert response.status_code == codes.OK + return data["access_token"] + + +@pytest_asyncio.fixture(loop_scope="class") +async def setup_session() -> AsyncGenerator[tuple[AsyncClient, str], None]: + async with async_client() as client: + token: str = await get_token(client) + assert token, "Failed to get token" + + yield client, token + + +@pytest.mark.asyncio(loop_scope="class") +class TestAPI: + async def endpoint( + self, + client: AsyncClient, + token: str, + endpoint: str, + url: str, + params: Optional[dict] = None, + expected_status: int = codes.OK, + ) -> Union[dict, str]: + """Test an endpoint with token and print results.""" + path = f"/{endpoint}/{quote(url)}" + if params: + path += "?" 
+ "&".join(f"{k}={v}" for k, v in params.items()) + + headers = {"Authorization": f"Bearer {token}"} + print(f"\nTesting: {path}") + + response: Response = await client.get(path, headers=headers) + content_type: str = response.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower() + data: Union[dict, str] = ( + response.json() if content_type == "application/json" else response.text + ) + print(f"Status: {response.status_code} (Expected: {expected_status})") + if isinstance(data, dict): + print(f"Response: {json.dumps(data, indent=2)}") + else: + print(f"Response: {data[:500]}...") # First 500 chars + assert response.status_code == expected_status, ( + f"Expected {expected_status}, got {response.status_code}" + ) + return data + + async def stream_crawl(self, client: AsyncClient, token: str): + """Test the /crawl/stream endpoint with multiple URLs.""" + url = "/crawl/stream" + payload = { + "urls": [ + "https://example.com", + "https://example.com/page1", # Replicated example.com with variation + "https://example.com/page2", # Replicated example.com with variation + "https://example.com/page3", # Replicated example.com with variation + ], + "browser_config": {"headless": True, "viewport": {"width": 1200}}, + "crawler_config": {"stream": True, "cache_mode": "aggressive"}, + } + headers = {"Authorization": f"Bearer {token}"} + print(f"\nTesting Streaming Crawl: {url}") + print(f"Payload: {json.dumps(payload, indent=2)}") + + async with client.stream( + "POST", url, json=payload, headers=headers + ) as response: + print(f"Status: {response.status_code} (Expected: {codes.OK})") + assert response.status_code == codes.OK, ( + f"Expected {codes.OK}, got {response.status_code}" + ) + # Read streaming response line-by-line (NDJSON) - async for line in response.content: + async for line in response.aiter_lines(): if line: - data = json.loads(line.decode('utf-8').strip()) + data = json.loads(line.strip()) print(f"Streamed Result: {json.dumps(data, indent=2)}") - except Exception as e: - print(f"Error in streaming crawl test: {str(e)}") - -async def run_tests(): - import aiohttp - print("Starting API Tests...") - - # Test URLs - urls = [ - "example.com", - "https://www.python.org", - "https://news.ycombinator.com/news", - "https://github.com/trending" - ] - - async with aiohttp.ClientSession() as session: - # Fetch token once and reuse it - token = await get_token(session) - if not token: - print("Aborting tests due to token failure!") - return - - print("\n=== Testing Crawl Endpoint ===") + + async def test_crawl_endpoint(self, setup_session: tuple[AsyncClient, str]): + client, token = setup_session crawl_payload = { "urls": ["https://example.com"], "browser_config": {"headless": True}, - "crawler_config": {"stream": False} + "crawler_config": {"stream": False}, } - async with session.post( - "http://localhost:8000/crawl", + response = await client.post( + "/crawl", json=crawl_payload, - headers={"Authorization": f"Bearer {token}"} - ) as response: - status = response.status - data = await response.json() - print(f"\nCrawl Endpoint Status: {status}") - print(f"Crawl Response: {json.dumps(data, indent=2)}") - - - print("\n=== Testing Crawl Stream Endpoint ===") - await test_stream_crawl(session, token) - - print("\n=== Testing Markdown Endpoint ===") - for url in []: #urls: - for filter_type in ["raw", "fit", "bm25", "llm"]: - params = {"f": filter_type} - if filter_type in ["bm25", "llm"]: - params["q"] = "extract main content" - - for cache in ["0", "1"]: - params["c"] = cache - await 
test_endpoint(session, "md", url, token, params) - await asyncio.sleep(1) # Be nice to the server - - print("\n=== Testing LLM Endpoint ===") - for url in urls: - # Test basic extraction (direct response now) - result = await test_endpoint( - session, - "llm", - url, - token, - {"q": "Extract title and main content"} - ) - - # Test with schema (direct response) - schema = { - "type": "object", - "properties": { - "title": {"type": "string"}, - "content": {"type": "string"}, - "links": {"type": "array", "items": {"type": "string"}} - } - } - result = await test_endpoint( - session, - "llm", - url, - token, - { - "q": "Extract content with links", - "s": json.dumps(schema), - "c": "1" # Test with cache - } - ) - await asyncio.sleep(2) # Be nice to the server - - print("\n=== Testing Error Cases ===") - # Test invalid URL - await test_endpoint( - session, - "md", - "not_a_real_url", + headers={"Authorization": f"Bearer {token}"}, + ) + data = response.json() + print(f"\nCrawl Endpoint Status: {response.status_code}") + print(f"Crawl Response: {json.dumps(data, indent=2)}") + + @pytest.mark.asyncio + async def test_stream_endpoint(self, setup_session: tuple[AsyncClient, str]): + client, token = setup_session + await self.stream_crawl(client, token) + + @pytest.mark.asyncio + @pytest.mark.parametrize("url,params", markdown_params()) + @pytest.mark.timeout(60) # LLM tasks can take a while. + async def test_markdown_endpoint( + self, + url: str, + params: dict[str, str], + setup_session: tuple[AsyncClient, str], + ): + client, token = setup_session + result: Union[dict, str] = await self.endpoint(client, token, "md", url, params) + assert isinstance(result, str), "Expected str response" + assert result, "Expected non-empty response" + + @pytest.mark.parametrize("url", TEST_URLS) + async def test_llm_endpoint_no_schema( + self, url: str, setup_session: tuple[AsyncClient, str] + ): + client, token = setup_session + # Test basic extraction (direct response now) + result = await self.endpoint( + client, token, - expected_status=500 + "llm", + url, + {"q": "Extract title and main content"}, + ) + assert isinstance(result, dict), "Expected dict response" + assert "answer" in result, "Expected 'answer' key in response" + + # Currently the server handles LLM requests using handle_llm_qa + # which doesn't use handle_llm_request which is where the schema + # is processed. 
+ @pytest.mark.parametrize("url", TEST_URLS) + @pytest.mark.skip("LLM endpoint doesn't schema request yet") + async def test_llm_endpoint_schema( + self, url: str, setup_session: tuple[AsyncClient, str] + ): + client, token = setup_session + schema = { + "type": "object", + "properties": { + "title": {"type": "string"}, + "content": {"type": "string"}, + "links": {"type": "array", "items": {"type": "string"}}, + }, + } + result = await self.endpoint( + client, + token, + "llm", + url, + { + "q": "Extract content with links", + "s": json.dumps(schema), + "c": "1", # Test with cache + }, ) - - # Test invalid filter type - await test_endpoint( - session, + assert isinstance(result, dict), "Expected dict response" + assert "answer" in result, "Expected 'answer' key in response" + print(result) + + async def test_invalid_url(self, setup_session: tuple[AsyncClient, str]): + client, token = setup_session + await self.endpoint(client, token, "md", "not_a_real_url", expected_status=codes.INTERNAL_SERVER_ERROR) + + async def test_invalid_filter(self, setup_session: tuple[AsyncClient, str]): + client, token = setup_session + await self.endpoint( + client, + token, "md", "example.com", - token, {"f": "invalid"}, - expected_status=422 + expected_status=codes.UNPROCESSABLE_ENTITY, ) - + + async def test_missing_query(self, setup_session: tuple[AsyncClient, str]): + client, token = setup_session # Test LLM without query (should fail per your server logic) - await test_endpoint( - session, - "llm", - "example.com", - token, - expected_status=400 - ) - - print("\nAll tests completed!") + await self.endpoint(client, token, "llm", "example.com", expected_status=codes.BAD_REQUEST) + if __name__ == "__main__": - asyncio.run(run_tests()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/docker_example.py b/tests/docker_example.py index 336ca52f6..49b878d9e 100644 --- a/tests/docker_example.py +++ b/tests/docker_example.py @@ -1,14 +1,17 @@ +from httpx import codes import requests import json import time import sys import base64 import os -from typing import Dict, Any +from typing import Dict, Any, Optional class Crawl4AiTester: - def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None): + def __init__( + self, base_url: str = "http://localhost:11235", api_token: Optional[str] = None + ): self.base_url = base_url self.api_token = api_token or os.getenv( "CRAWL4AI_API_TOKEN" @@ -24,7 +27,7 @@ def submit_and_wait( response = requests.post( f"{self.base_url}/crawl", json=request_data, headers=self.headers ) - if response.status_code == 403: + if response.status_code == codes.FORBIDDEN: raise Exception("API token is invalid or missing") task_id = response.json()["task_id"] print(f"Task ID: {task_id}") @@ -58,7 +61,7 @@ def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]: headers=self.headers, timeout=60, ) - if response.status_code == 408: + if response.status_code == codes.REQUEST_TIMEOUT: raise TimeoutError("Task did not complete within server timeout") response.raise_for_status() return response.json() @@ -66,7 +69,6 @@ def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]: def test_docker_deployment(version="basic"): tester = Crawl4AiTester( - # base_url="http://localhost:11235" , base_url="https://crawl4ai-sby74.ondigitalocean.app", api_token="test", ) diff --git a/tests/hub/test_simple.py b/tests/hub/test_simple.py index a970d683c..a3d45b174 100644 --- 
a/tests/hub/test_simple.py +++ b/tests/hub/test_simple.py @@ -1,34 +1,50 @@ # test.py -from crawl4ai import CrawlerHub import json +import sys +from pathlib import Path + +import pytest + +from crawl4ai import CrawlerHub + -async def amazon_example(): - if (crawler_cls := CrawlerHub.get("amazon_product")) : - crawler = crawler_cls() - print(f"Crawler version: {crawler_cls.meta['version']}") - print(f"Rate limits: {crawler_cls.meta.get('rate_limit', 'Unlimited')}") - print(await crawler.run("https://amazon.com/test")) - else: - print("Crawler not found!") +@pytest.mark.asyncio +async def test_amazon(): + crawler_cls = CrawlerHub.get("amazon_product") + assert crawler_cls is not None -async def google_example(): + crawler = crawler_cls() + print(f"Crawler version: {crawler.meta['version']}") + print(f"Rate limits: {crawler.meta.get('rate_limit', 'Unlimited')}") + result = await crawler.run("https://amazon.com/test") + assert result and "product" in result + print(result) + + +@pytest.mark.asyncio +@pytest.mark.skip("crawler doesn't pass llm_config to generate_schema, so it fails") +async def test_google(tmp_path: Path): # Get crawler dynamically crawler_cls = CrawlerHub.get("google_search") + assert crawler_cls is not None crawler = crawler_cls() # Text search + schema_cache_path: Path = tmp_path / ".crawl4ai" text_results = await crawler.run( - query="apple inc", - search_type="text", - schema_cache_path="/Users/unclecode/.crawl4ai" + query="apple inc", + search_type="text", + schema_cache_path=schema_cache_path.as_posix(), ) print(json.dumps(json.loads(text_results), indent=4)) + assert text_results # Image search # image_results = await crawler.run(query="apple inc", search_type="image") # print(image_results) + if __name__ == "__main__": - import asyncio - # asyncio.run(amazon_example()) - asyncio.run(google_example()) \ No newline at end of file + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/legacy/test_cli_docs.py b/tests/legacy/test_cli_docs.py new file mode 100644 index 000000000..be70d7c80 --- /dev/null +++ b/tests/legacy/test_cli_docs.py @@ -0,0 +1,20 @@ +import sys + +import pytest + +from crawl4ai.legacy.docs_manager import DocsManager + + +@pytest.mark.asyncio +async def test_cli(): + """Test all CLI commands""" + print("\n1. Testing docs update...") + docs_manager = DocsManager() + result = await docs_manager.fetch_docs() + assert result + + +if __name__ == "__main__": + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/loggers/test_logger.py b/tests/loggers/test_logger.py index 6c3a811b4..a8a074592 100644 --- a/tests/loggers/test_logger.py +++ b/tests/loggers/test_logger.py @@ -1,7 +1,18 @@ -import asyncio -from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode, AsyncLoggerBase import os +import sys from datetime import datetime +from pathlib import Path + +import pytest + +from crawl4ai import ( + AsyncLoggerBase, + AsyncWebCrawler, + BrowserConfig, + CacheMode, + CrawlerRunConfig, +) + class AsyncFileLogger(AsyncLoggerBase): """ @@ -55,26 +66,25 @@ def error_status(self, url: str, error: str, tag: str = "ERROR", url_length: int message = f"{url[:url_length]}... 
| Error: {error}" self._write_to_file("ERROR", message, tag) -async def main(): + +@pytest.mark.asyncio +async def test_logger(tmp_path: Path): + log_file: Path = tmp_path / "test.log" browser_config = BrowserConfig(headless=True, verbose=True) - crawler = AsyncWebCrawler(config=browser_config, logger=AsyncFileLogger("/Users/unclecode/devs/crawl4ai/.private/tmp/crawl.log")) - await crawler.start() - - try: + async with AsyncWebCrawler(config=browser_config,logger=AsyncFileLogger(log_file.as_posix())) as crawler: crawl_config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, ) + # Use the crawler multiple times - result = await crawler.arun( - url='https://kidocode.com/', - config=crawl_config - ) - if result.success: - print("First crawl - Raw Markdown Length:", len(result.markdown.raw_markdown)) - - finally: - # Always ensure we close the crawler - await crawler.close() + for _ in range(3): + result = await crawler.arun( + url='https://kidocode.com/', + config=crawl_config + ) + assert result.success if __name__ == "__main__": - asyncio.run(main()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/memory/test_crawler_monitor.py b/tests/memory/test_crawler_monitor.py index 89cc08b84..b186bf630 100644 --- a/tests/memory/test_crawler_monitor.py +++ b/tests/memory/test_crawler_monitor.py @@ -9,6 +9,7 @@ import threading import sys import os +import pytest # Add the parent directory to the path to import crawl4ai sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) @@ -59,7 +60,7 @@ def simulate_crawler_task(monitor, task_id, url, simulate_failure=False): memory_usage=0.0 ) -def update_queue_stats(monitor, num_queued_tasks): +def update_queue_stats(monitor: CrawlerMonitor, num_queued_tasks): """Update queue statistics periodically.""" while monitor.is_running: queued_tasks = [ @@ -102,6 +103,7 @@ def update_queue_stats(monitor, num_queued_tasks): time.sleep(1.0) +@pytest.mark.timeout(60) def test_crawler_monitor(): """Test the CrawlerMonitor with simulated crawler tasks.""" # Total number of URLs to crawl @@ -165,4 +167,6 @@ def test_crawler_monitor(): print("\nCrawler monitor test completed") if __name__ == "__main__": + # Doesn't use the standard pytest entry point as it's not compatible + # with the tty output in the terminal. 
test_crawler_monitor() \ No newline at end of file diff --git a/tests/memory/test_dispatcher_stress.py b/tests/memory/test_dispatcher_stress.py index f81f78f65..684546cc2 100644 --- a/tests/memory/test_dispatcher_stress.py +++ b/tests/memory/test_dispatcher_stress.py @@ -3,13 +3,14 @@ import psutil import logging import random -from typing import List, Dict +from typing import List, Dict, Optional import uuid import sys import os +import pytest # Import your crawler components -from crawl4ai.models import DisplayMode, CrawlStatus, CrawlResult +from crawl4ai.models import CrawlStatus from crawl4ai.async_configs import CrawlerRunConfig, BrowserConfig, CacheMode from crawl4ai import AsyncWebCrawler from crawl4ai import MemoryAdaptiveDispatcher, CrawlerMonitor @@ -73,8 +74,8 @@ def apply_pressure(self, additional_percent: float = 0.0): time.sleep(0.5) # Give system time to register the allocation except MemoryError: logger.warning("Unable to allocate more memory") - - def release_pressure(self, percent: float = None): + + def release_pressure(self, percent: Optional[float] = None): """ Release allocated memory If percent is specified, release that percentage of blocks @@ -101,24 +102,24 @@ def spike_pressure(self, duration: float = 5.0): logger.info(f"Creating memory pressure spike for {duration} seconds") # Save current blocks count initial_blocks = len(self.memory_blocks) - + # Create spike with extra 5% self.apply_pressure(additional_percent=5.0) - + # Schedule release after duration asyncio.create_task(self._delayed_release(duration, initial_blocks)) - + async def _delayed_release(self, delay: float, target_blocks: int): """Helper for spike_pressure - releases extra blocks after delay""" await asyncio.sleep(delay) - + # Remove blocks added since spike started if len(self.memory_blocks) > target_blocks: logger.info(f"Releasing memory spike ({len(self.memory_blocks) - target_blocks} blocks)") self.memory_blocks = self.memory_blocks[:target_blocks] - + # Test statistics collector -class TestResults: +class StressTestResults: def __init__(self): self.start_time = time.time() self.completed_urls: List[str] = [] @@ -167,7 +168,7 @@ def log_summary(self): # Custom monitor with stats tracking # Custom monitor that extends CrawlerMonitor with test-specific tracking class StressTestMonitor(CrawlerMonitor): - def __init__(self, test_results: TestResults, **kwargs): + def __init__(self, test_results: StressTestResults, **kwargs): # Initialize the parent CrawlerMonitor super().__init__(**kwargs) self.test_results = test_results @@ -192,32 +193,51 @@ def update_queue_statistics(self, total_queued: int, highest_wait_time: float, a # Call parent method to update the dashboard super().update_queue_statistics(total_queued, highest_wait_time, avg_wait_time) - - def update_task(self, task_id: str, **kwargs): + + def update_task( + self, + task_id: str, + status: Optional[CrawlStatus] = None, + start_time: Optional[float] = None, + end_time: Optional[float] = None, + memory_usage: Optional[float] = None, + peak_memory: Optional[float] = None, + error_message: Optional[str] = None, + retry_count: Optional[int] = None, + wait_time: Optional[float] = None + ): # Track URL status changes for test results if task_id in self.stats: - old_status = self.stats[task_id].status - + old_status = self.stats[task_id]["status"] + # If this is a requeue event (requeued due to memory pressure) - if 'error_message' in kwargs and 'requeued' in kwargs['error_message']: - if not hasattr(self.stats[task_id], 'counted_requeue') or 
not self.stats[task_id].counted_requeue: + if error_message and 'requeued' in error_message: + if not self.stats[task_id]["counted_requeue"]: self.test_results.requeued_count += 1 - self.stats[task_id].counted_requeue = True - + self.stats[task_id]["counted_requeue"] = True + # Track completion status for test results - if 'status' in kwargs: - new_status = kwargs['status'] - if old_status != new_status: - if new_status == CrawlStatus.COMPLETED: + if status: + if old_status != status: + if status == CrawlStatus.COMPLETED: if task_id not in self.test_results.completed_urls: self.test_results.completed_urls.append(task_id) - elif new_status == CrawlStatus.FAILED: + elif status == CrawlStatus.FAILED: if task_id not in self.test_results.failed_urls: self.test_results.failed_urls.append(task_id) # Call parent method to update the dashboard - super().update_task(task_id, **kwargs) - self.live.update(self._create_table()) + super().update_task( + task_id=task_id, + status=status, + start_time=start_time, + end_time=end_time, + memory_usage=memory_usage, + peak_memory=peak_memory, + error_message=error_message, + retry_count=retry_count, + wait_time=wait_time + ) # Generate test URLs - use example.com with unique paths to avoid browser caching def generate_test_urls(count: int) -> List[str]: @@ -230,7 +250,7 @@ def generate_test_urls(count: int) -> List[str]: return urls # Process result callback -async def process_result(result, test_results: TestResults): +async def process_result(result, test_results: StressTestResults): # Track attempt counts if result.url not in test_results.url_to_attempt: test_results.url_to_attempt[result.url] = 1 @@ -248,7 +268,7 @@ async def process_result(result, test_results: TestResults): logger.warning(f"Failed to process: {result.url} - {result.error_message}") # Process multiple results (used in non-streaming mode) -async def process_results(results, test_results: TestResults): +async def process_results(results, test_results: StressTestResults): for result in results: await process_result(result, test_results) @@ -260,7 +280,25 @@ async def run_memory_stress_test( aggressive: bool = False, spikes: bool = True ): - test_results = TestResults() + # Scale values based on initial memory usage. 
+ # With no initial memory pressure, we can use the default values: + # - 55% is the threshold for normal operation - plenty of memory + # - 63% is the threshold for throttling + # - 70% is the threshold for requeuing - incredibly aggressive + initial_percent: float = psutil.virtual_memory().percent + baseline: float = 55.0 + if initial_percent > baseline: + baseline = initial_percent + 5.0 + + memory_threshold_percent = baseline + 8 + critical_threshold_percent = baseline + 15 + recovery_threshold_percent = baseline + + if critical_threshold_percent > 99.0: + pytest.skip("Memory pressure too high for this system") + + test_results = StressTestResults() + memory_simulator = MemorySimulator(target_percent=target_memory_percent, aggressive=aggressive) logger.info(f"Starting stress test with {url_count} URLs in {'STREAM' if STREAM else 'NON-STREAM'} mode") @@ -285,17 +323,15 @@ async def run_memory_stress_test( # Create monitor with reference to test results monitor = StressTestMonitor( test_results=test_results, - display_mode=DisplayMode.DETAILED, - max_visible_rows=20, - total_urls=url_count # Pass total URLs count + urls_total=url_count # Pass total URLs count ) # Create dispatcher with EXTREME settings - pure survival mode # These settings are designed to create a memory battleground dispatcher = MemoryAdaptiveDispatcher( - memory_threshold_percent=63.0, # Start throttling at just 60% memory - critical_threshold_percent=70.0, # Start requeuing at 70% - incredibly aggressive - recovery_threshold_percent=55.0, # Only resume normal ops when plenty of memory available + memory_threshold_percent=memory_threshold_percent, + critical_threshold_percent=critical_threshold_percent, + recovery_threshold_percent=recovery_threshold_percent, check_interval=0.1, # Check extremely frequently (100ms) max_session_permit=20 if aggressive else 10, # Double the concurrent sessions - pure chaos fairness_timeout=10.0, # Extremely low timeout - rapid priority changes @@ -303,8 +339,8 @@ async def run_memory_stress_test( ) # Set up spike schedule if enabled + spike_intervals = [] if spikes: - spike_intervals = [] # Create 3-5 random spike times num_spikes = random.randint(3, 5) for _ in range(num_spikes): @@ -379,6 +415,12 @@ async def run_memory_stress_test( logger.info("TEST PASSED: All URLs were processed without crashing.") return True +@pytest.mark.asyncio +@pytest.mark.timeout(600) +async def test_memory_stress(): + """Run the memory stress test with default parameters.""" + assert await run_memory_stress_test() + # Command-line entry point if __name__ == "__main__": # Parse command line arguments diff --git a/tests/test_cli_docs.py b/tests/test_cli_docs.py deleted file mode 100644 index 6941f20db..000000000 --- a/tests/test_cli_docs.py +++ /dev/null @@ -1,44 +0,0 @@ -import asyncio -from crawl4ai.docs_manager import DocsManager -from click.testing import CliRunner -from crawl4ai.cli import cli - - -def test_cli(): - """Test all CLI commands""" - runner = CliRunner() - - print("\n1. Testing docs update...") - # Use sync version for testing - docs_manager = DocsManager() - loop = asyncio.get_event_loop() - loop.run_until_complete(docs_manager.fetch_docs()) - - # print("\n2. Testing listing...") - # result = runner.invoke(cli, ['docs', 'list']) - # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") - # print(result.output) - - # print("\n2. 
Testing index building...") - # result = runner.invoke(cli, ['docs', 'index']) - # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") - # print(f"Output: {result.output}") - - # print("\n3. Testing search...") - # result = runner.invoke(cli, ['docs', 'search', 'how to use crawler', '--build-index']) - # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") - # print(f"First 200 chars: {result.output[:200]}...") - - # print("\n4. Testing combine with sections...") - # result = runner.invoke(cli, ['docs', 'combine', 'chunking_strategies', 'extraction_strategies', '--mode', 'extended']) - # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") - # print(f"First 200 chars: {result.output[:200]}...") - - print("\n5. Testing combine all sections...") - result = runner.invoke(cli, ["docs", "combine", "--mode", "condensed"]) - print(f"Status: {'✅' if result.exit_code == 0 else '❌'}") - print(f"First 200 chars: {result.output[:200]}...") - - -if __name__ == "__main__": - test_cli() diff --git a/tests/test_crawl_result_container.py b/tests/test_crawl_result_container.py new file mode 100644 index 000000000..84cd3b7ae --- /dev/null +++ b/tests/test_crawl_result_container.py @@ -0,0 +1,133 @@ +from typing import Any, AsyncGenerator + +import pytest +from _pytest.mark.structures import ParameterSet + +from crawl4ai.models import CrawlResultContainer, CrawlResult + +RESULT: CrawlResult = CrawlResult( + url="https://example.com", success=True, html="Test content" +) + + +def result_container_params() -> list[ParameterSet]: + """Return a list of test parameters for the CrawlResultContainer tests. + + :return: List of test parameters for CrawlResultContainer tests containing tuple[result:CrawlResultContainer, expected:list[CrawlResult]] + :rtype: list[ParameterSet] + """ + + async def async_generator(results: list[CrawlResult]) -> AsyncGenerator[CrawlResult, Any]: + for result in results: + yield result + + return [ + pytest.param(CrawlResultContainer(RESULT), [RESULT], id="result"), + pytest.param(CrawlResultContainer([]), [], id="list_empty"), + pytest.param(CrawlResultContainer([RESULT]), [RESULT], id="list_single"), + pytest.param(CrawlResultContainer([RESULT, RESULT]), [RESULT, RESULT], id="list_multi"), + pytest.param(CrawlResultContainer(async_generator([])), [], id="async_empty"), + pytest.param(CrawlResultContainer(async_generator([RESULT])), [RESULT], id="async_single"), + pytest.param(CrawlResultContainer(async_generator([RESULT, RESULT])), [RESULT, RESULT], id="async_multi"), + ] + + +@pytest.mark.parametrize("result,expected", result_container_params()) +def test_iter(result: CrawlResultContainer, expected: list[CrawlResult]): + """Test __iter__ of the CrawlResultContainer.""" + if isinstance(result.source, AsyncGenerator): + with pytest.raises(TypeError): + for item in result: + pass + return + + results: list[CrawlResult] = [] + for item in result: + results.append(item) + + assert results == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("result,expected", result_container_params()) +async def test_aiter(result: CrawlResultContainer, expected: list[CrawlResult]): + """Test __aiter__ of the CrawlResultContainer.""" + results: list[CrawlResult] = [] + async for item in result: + results.append(item) + + assert results == expected + + +@pytest.mark.parametrize("result,expected", result_container_params()) +def test_getitem(result: CrawlResultContainer, expected: list[CrawlResult]): + """Test the __getitem__ method of the CrawlResultContainer.""" + if 
isinstance(result.source, AsyncGenerator): + with pytest.raises(TypeError): + assert result[0] == expected[0] + + return + + for i in range(len(expected)): + assert result[i] == expected[i] + + +@pytest.mark.parametrize("result,expected", result_container_params()) +def test_len(result: CrawlResultContainer, expected: list[CrawlResult]): + """Test the __len__ of the CrawlResultContainer.""" + if isinstance(result.source, AsyncGenerator): + with pytest.raises(TypeError): + assert len(result) == len(expected) + + return + + assert len(result) == len(expected) + + +def result_attributes() -> list[str]: + """Return a list of attributes to test for CrawlResult + + :return: List of valid attributes, excluding private, callable, and deprecated attributes. + :rtype: list[str] + """ + + # We check hasattr to avoid class only attribute error and failing on deprecated attributes. + return [ + attr + for attr in dir(RESULT) + if attr in RESULT.model_fields or (hasattr(RESULT, attr) and isinstance(getattr(RESULT, attr), property)) + ] + + +@pytest.mark.parametrize("result,expected", result_container_params()) +def test_getattribute(result: CrawlResultContainer, expected: list[CrawlResult]): + """Test the __getattribute__ method of the CrawlResultContainer.""" + assert result.source is not None + + if isinstance(result.source, AsyncGenerator): + for attr in result_attributes(): + with pytest.raises(TypeError): + assert getattr(result, attr) == getattr(RESULT, attr) + + return + + if not expected: + for attr in result_attributes(): + with pytest.raises(AttributeError): + assert getattr(result, attr) == getattr(RESULT, attr) + + return + + for attr in result_attributes(): + assert getattr(result, attr) == getattr(RESULT, attr) + + +@pytest.mark.parametrize("result,expected", result_container_params()) +def test_repr(result: CrawlResultContainer, expected: list[CrawlResult]): + """Test the __repr__ method of the CrawlResultContainer.""" + if isinstance(result.source, AsyncGenerator): + assert repr(result) == "CrawlResultContainer([])" + + return + + assert repr(result) == f"CrawlResultContainer({repr(expected)})" diff --git a/tests/test_docker.py b/tests/test_docker.py deleted file mode 100644 index 3570d608d..000000000 --- a/tests/test_docker.py +++ /dev/null @@ -1,299 +0,0 @@ -import requests -import json -import time -import sys -import base64 -import os -from typing import Dict, Any - - -class Crawl4AiTester: - def __init__(self, base_url: str = "http://localhost:11235"): - self.base_url = base_url - - def submit_and_wait( - self, request_data: Dict[str, Any], timeout: int = 300 - ) -> Dict[str, Any]: - # Submit crawl job - response = requests.post(f"{self.base_url}/crawl", json=request_data) - task_id = response.json()["task_id"] - print(f"Task ID: {task_id}") - - # Poll for result - start_time = time.time() - while True: - if time.time() - start_time > timeout: - raise TimeoutError( - f"Task {task_id} did not complete within {timeout} seconds" - ) - - result = requests.get(f"{self.base_url}/task/{task_id}") - status = result.json() - - if status["status"] == "failed": - print("Task failed:", status.get("error")) - raise Exception(f"Task failed: {status.get('error')}") - - if status["status"] == "completed": - return status - - time.sleep(2) - - -def test_docker_deployment(version="basic"): - tester = Crawl4AiTester() - print(f"Testing Crawl4AI Docker {version} version") - - # Health check with timeout and retry - max_retries = 5 - for i in range(max_retries): - try: - health = 
requests.get(f"{tester.base_url}/health", timeout=10) - print("Health check:", health.json()) - break - except requests.exceptions.RequestException: - if i == max_retries - 1: - print(f"Failed to connect after {max_retries} attempts") - sys.exit(1) - print(f"Waiting for service to start (attempt {i+1}/{max_retries})...") - time.sleep(5) - - # Test cases based on version - test_basic_crawl(tester) - - # if version in ["full", "transformer"]: - # test_cosine_extraction(tester) - - # test_js_execution(tester) - # test_css_selector(tester) - # test_structured_extraction(tester) - # test_llm_extraction(tester) - # test_llm_with_ollama(tester) - # test_screenshot(tester) - - -def test_basic_crawl(tester: Crawl4AiTester): - print("\n=== Testing Basic Crawl ===") - request = {"urls": "https://www.nbcnews.com/business", "priority": 10} - - result = tester.submit_and_wait(request) - print(f"Basic crawl result length: {len(result['result']['markdown'])}") - assert result["result"]["success"] - assert len(result["result"]["markdown"]) > 0 - - -def test_js_execution(tester: Crawl4AiTester): - print("\n=== Testing JS Execution ===") - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 8, - "js_code": [ - "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" - ], - "wait_for": "article.tease-card:nth-child(10)", - "crawler_params": {"headless": True}, - } - - result = tester.submit_and_wait(request) - print(f"JS execution result length: {len(result['result']['markdown'])}") - assert result["result"]["success"] - - -def test_css_selector(tester: Crawl4AiTester): - print("\n=== Testing CSS Selector ===") - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 7, - "css_selector": ".wide-tease-item__description", - "crawler_params": {"headless": True}, - "extra": {"word_count_threshold": 10}, - } - - result = tester.submit_and_wait(request) - print(f"CSS selector result length: {len(result['result']['markdown'])}") - assert result["result"]["success"] - - -def test_structured_extraction(tester: Crawl4AiTester): - print("\n=== Testing Structured Extraction ===") - schema = { - "name": "Coinbase Crypto Prices", - "baseSelector": ".cds-tableRow-t45thuk", - "fields": [ - { - "name": "crypto", - "selector": "td:nth-child(1) h2", - "type": "text", - }, - { - "name": "symbol", - "selector": "td:nth-child(1) p", - "type": "text", - }, - { - "name": "price", - "selector": "td:nth-child(2)", - "type": "text", - }, - ], - } - - request = { - "urls": "https://www.coinbase.com/explore", - "priority": 9, - "extraction_config": {"type": "json_css", "params": {"schema": schema}}, - } - - result = tester.submit_and_wait(request) - extracted = json.loads(result["result"]["extracted_content"]) - print(f"Extracted {len(extracted)} items") - print("Sample item:", json.dumps(extracted[0], indent=2)) - assert result["result"]["success"] - assert len(extracted) > 0 - - -def test_llm_extraction(tester: Crawl4AiTester): - print("\n=== Testing LLM Extraction ===") - schema = { - "type": "object", - "properties": { - "model_name": { - "type": "string", - "description": "Name of the OpenAI model.", - }, - "input_fee": { - "type": "string", - "description": "Fee for input token for the OpenAI model.", - }, - "output_fee": { - "type": "string", - "description": "Fee for output token for the OpenAI model.", - }, - }, - "required": ["model_name", "input_fee", "output_fee"], - } - - 
request = { - "urls": "https://openai.com/api/pricing", - "priority": 8, - "extraction_config": { - "type": "llm", - "params": { - "provider": "openai/gpt-4o-mini", - "api_token": os.getenv("OPENAI_API_KEY"), - "schema": schema, - "extraction_type": "schema", - "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""", - }, - }, - "crawler_params": {"word_count_threshold": 1}, - } - - try: - result = tester.submit_and_wait(request) - extracted = json.loads(result["result"]["extracted_content"]) - print(f"Extracted {len(extracted)} model pricing entries") - print("Sample entry:", json.dumps(extracted[0], indent=2)) - assert result["result"]["success"] - except Exception as e: - print(f"LLM extraction test failed (might be due to missing API key): {str(e)}") - - -def test_llm_with_ollama(tester: Crawl4AiTester): - print("\n=== Testing LLM with Ollama ===") - schema = { - "type": "object", - "properties": { - "article_title": { - "type": "string", - "description": "The main title of the news article", - }, - "summary": { - "type": "string", - "description": "A brief summary of the article content", - }, - "main_topics": { - "type": "array", - "items": {"type": "string"}, - "description": "Main topics or themes discussed in the article", - }, - }, - } - - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 8, - "extraction_config": { - "type": "llm", - "params": { - "provider": "ollama/llama2", - "schema": schema, - "extraction_type": "schema", - "instruction": "Extract the main article information including title, summary, and main topics.", - }, - }, - "extra": {"word_count_threshold": 1}, - "crawler_params": {"verbose": True}, - } - - try: - result = tester.submit_and_wait(request) - extracted = json.loads(result["result"]["extracted_content"]) - print("Extracted content:", json.dumps(extracted, indent=2)) - assert result["result"]["success"] - except Exception as e: - print(f"Ollama extraction test failed: {str(e)}") - - -def test_cosine_extraction(tester: Crawl4AiTester): - print("\n=== Testing Cosine Extraction ===") - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 8, - "extraction_config": { - "type": "cosine", - "params": { - "semantic_filter": "business finance economy", - "word_count_threshold": 10, - "max_dist": 0.2, - "top_k": 3, - }, - }, - } - - try: - result = tester.submit_and_wait(request) - extracted = json.loads(result["result"]["extracted_content"]) - print(f"Extracted {len(extracted)} text clusters") - print("First cluster tags:", extracted[0]["tags"]) - assert result["result"]["success"] - except Exception as e: - print(f"Cosine extraction test failed: {str(e)}") - - -def test_screenshot(tester: Crawl4AiTester): - print("\n=== Testing Screenshot ===") - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 5, - "screenshot": True, - "crawler_params": {"headless": True}, - } - - result = tester.submit_and_wait(request) - print("Screenshot captured:", bool(result["result"]["screenshot"])) - - if result["result"]["screenshot"]: - # Save screenshot - screenshot_data = base64.b64decode(result["result"]["screenshot"]) - with open("test_screenshot.jpg", "wb") as f: - f.write(screenshot_data) - print("Screenshot saved as test_screenshot.jpg") - - assert result["result"]["success"] - - -if __name__ == "__main__": - version = sys.argv[1] if len(sys.argv) > 1 else "basic" - # version = "full" - test_docker_deployment(version) diff --git 
a/tests/test_llmtxt.py b/tests/test_llmtxt.py index 2cdb02715..6910a2f43 100644 --- a/tests/test_llmtxt.py +++ b/tests/test_llmtxt.py @@ -1,12 +1,15 @@ -from crawl4ai.llmtxt import AsyncLLMTextManager # Changed to AsyncLLMTextManager -from crawl4ai.async_logger import AsyncLogger +import sys from pathlib import Path -import asyncio +import pytest + +from crawl4ai.async_logger import AsyncLogger +from crawl4ai.legacy.llmtxt import AsyncLLMTextManager -async def main(): + +@pytest.mark.asyncio +async def test_llm_txt(): current_file = Path(__file__).resolve() - # base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs" base_dir = current_file.parent.parent / "local/_docs/llm.txt" docs_dir = base_dir @@ -16,7 +19,6 @@ async def main(): # Initialize logger logger = AsyncLogger() # Updated initialization with default batching params - # manager = AsyncLLMTextManager(docs_dir, logger, max_concurrent_calls=3, batch_size=2) manager = AsyncLLMTextManager(docs_dir, logger, batch_size=2) # Let's first check what files we have @@ -39,14 +41,13 @@ async def main(): for query in test_queries: print(f"\nQuery: {query}") results = manager.search(query, top_k=2) - print(f"Results length: {len(results)} characters") - if results: - print( - "First 200 chars of results:", results[:200].replace("\n", " "), "..." - ) - else: - print("No results found") + assert results, "No results found" + print( + "First 200 chars of results:", results[:200].replace("\n", " "), "..." + ) if __name__ == "__main__": - asyncio.run(main()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 0e938f590..000000000 --- a/tests/test_main.py +++ /dev/null @@ -1,276 +0,0 @@ -import asyncio -import aiohttp -import json -import time -import os -from typing import Dict, Any - - -class NBCNewsAPITest: - def __init__(self, base_url: str = "http://localhost:8000"): - self.base_url = base_url - self.session = None - - async def __aenter__(self): - self.session = aiohttp.ClientSession() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.session: - await self.session.close() - - async def submit_crawl(self, request_data: Dict[str, Any]) -> str: - async with self.session.post( - f"{self.base_url}/crawl", json=request_data - ) as response: - result = await response.json() - return result["task_id"] - - async def get_task_status(self, task_id: str) -> Dict[str, Any]: - async with self.session.get(f"{self.base_url}/task/{task_id}") as response: - return await response.json() - - async def wait_for_task( - self, task_id: str, timeout: int = 300, poll_interval: int = 2 - ) -> Dict[str, Any]: - start_time = time.time() - while True: - if time.time() - start_time > timeout: - raise TimeoutError( - f"Task {task_id} did not complete within {timeout} seconds" - ) - - status = await self.get_task_status(task_id) - if status["status"] in ["completed", "failed"]: - return status - - await asyncio.sleep(poll_interval) - - async def check_health(self) -> Dict[str, Any]: - async with self.session.get(f"{self.base_url}/health") as response: - return await response.json() - - -async def test_basic_crawl(): - print("\n=== Testing Basic Crawl ===") - async with NBCNewsAPITest() as api: - request = {"urls": "https://www.nbcnews.com/business", "priority": 10} - task_id = await api.submit_crawl(request) - result = await api.wait_for_task(task_id) - print(f"Basic crawl result length: 
{len(result['result']['markdown'])}") - assert result["status"] == "completed" - assert "result" in result - assert result["result"]["success"] - - -async def test_js_execution(): - print("\n=== Testing JS Execution ===") - async with NBCNewsAPITest() as api: - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 8, - "js_code": [ - "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();" - ], - "wait_for": "article.tease-card:nth-child(10)", - "crawler_params": {"headless": True}, - } - task_id = await api.submit_crawl(request) - result = await api.wait_for_task(task_id) - print(f"JS execution result length: {len(result['result']['markdown'])}") - assert result["status"] == "completed" - assert result["result"]["success"] - - -async def test_css_selector(): - print("\n=== Testing CSS Selector ===") - async with NBCNewsAPITest() as api: - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 7, - "css_selector": ".wide-tease-item__description", - } - task_id = await api.submit_crawl(request) - result = await api.wait_for_task(task_id) - print(f"CSS selector result length: {len(result['result']['markdown'])}") - assert result["status"] == "completed" - assert result["result"]["success"] - - -async def test_structured_extraction(): - print("\n=== Testing Structured Extraction ===") - async with NBCNewsAPITest() as api: - schema = { - "name": "NBC News Articles", - "baseSelector": "article.tease-card", - "fields": [ - {"name": "title", "selector": "h2", "type": "text"}, - { - "name": "description", - "selector": ".tease-card__description", - "type": "text", - }, - { - "name": "link", - "selector": "a", - "type": "attribute", - "attribute": "href", - }, - ], - } - - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 9, - "extraction_config": {"type": "json_css", "params": {"schema": schema}}, - } - task_id = await api.submit_crawl(request) - result = await api.wait_for_task(task_id) - extracted = json.loads(result["result"]["extracted_content"]) - print(f"Extracted {len(extracted)} articles") - assert result["status"] == "completed" - assert result["result"]["success"] - assert len(extracted) > 0 - - -async def test_batch_crawl(): - print("\n=== Testing Batch Crawl ===") - async with NBCNewsAPITest() as api: - request = { - "urls": [ - "https://www.nbcnews.com/business", - "https://www.nbcnews.com/business/consumer", - "https://www.nbcnews.com/business/economy", - ], - "priority": 6, - "crawler_params": {"headless": True}, - } - task_id = await api.submit_crawl(request) - result = await api.wait_for_task(task_id) - print(f"Batch crawl completed, got {len(result['results'])} results") - assert result["status"] == "completed" - assert "results" in result - assert len(result["results"]) == 3 - - -async def test_llm_extraction(): - print("\n=== Testing LLM Extraction with Ollama ===") - async with NBCNewsAPITest() as api: - schema = { - "type": "object", - "properties": { - "article_title": { - "type": "string", - "description": "The main title of the news article", - }, - "summary": { - "type": "string", - "description": "A brief summary of the article content", - }, - "main_topics": { - "type": "array", - "items": {"type": "string"}, - "description": "Main topics or themes discussed in the article", - }, - }, - "required": ["article_title", "summary", "main_topics"], - } - - request = { - "urls": 
"https://www.nbcnews.com/business", - "priority": 8, - "extraction_config": { - "type": "llm", - "params": { - "provider": "openai/gpt-4o-mini", - "api_key": os.getenv("OLLAMA_API_KEY"), - "schema": schema, - "extraction_type": "schema", - "instruction": """Extract the main article information including title, a brief summary, and main topics discussed. - Focus on the primary business news article on the page.""", - }, - }, - "crawler_params": {"headless": True, "word_count_threshold": 1}, - } - - task_id = await api.submit_crawl(request) - result = await api.wait_for_task(task_id) - - if result["status"] == "completed": - extracted = json.loads(result["result"]["extracted_content"]) - print("Extracted article analysis:") - print(json.dumps(extracted, indent=2)) - - assert result["status"] == "completed" - assert result["result"]["success"] - - -async def test_screenshot(): - print("\n=== Testing Screenshot ===") - async with NBCNewsAPITest() as api: - request = { - "urls": "https://www.nbcnews.com/business", - "priority": 5, - "screenshot": True, - "crawler_params": {"headless": True}, - } - task_id = await api.submit_crawl(request) - result = await api.wait_for_task(task_id) - print("Screenshot captured:", bool(result["result"]["screenshot"])) - assert result["status"] == "completed" - assert result["result"]["success"] - assert result["result"]["screenshot"] is not None - - -async def test_priority_handling(): - print("\n=== Testing Priority Handling ===") - async with NBCNewsAPITest() as api: - # Submit low priority task first - low_priority = { - "urls": "https://www.nbcnews.com/business", - "priority": 1, - "crawler_params": {"headless": True}, - } - low_task_id = await api.submit_crawl(low_priority) - - # Submit high priority task - high_priority = { - "urls": "https://www.nbcnews.com/business/consumer", - "priority": 10, - "crawler_params": {"headless": True}, - } - high_task_id = await api.submit_crawl(high_priority) - - # Get both results - high_result = await api.wait_for_task(high_task_id) - low_result = await api.wait_for_task(low_task_id) - - print("Both tasks completed") - assert high_result["status"] == "completed" - assert low_result["status"] == "completed" - - -async def main(): - try: - # Start with health check - async with NBCNewsAPITest() as api: - health = await api.check_health() - print("Server health:", health) - - # Run all tests - # await test_basic_crawl() - # await test_js_execution() - # await test_css_selector() - # await test_structured_extraction() - await test_llm_extraction() - # await test_batch_crawl() - # await test_screenshot() - # await test_priority_handling() - - except Exception as e: - print(f"Test failed: {str(e)}") - raise - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/tests/test_scraping_strategy.py b/tests/test_scraping_strategy.py index df4628540..54be5de0d 100644 --- a/tests/test_scraping_strategy.py +++ b/tests/test_scraping_strategy.py @@ -1,26 +1,29 @@ -import nest_asyncio +import sys -nest_asyncio.apply() +import pytest -import asyncio from crawl4ai import ( AsyncWebCrawler, + CacheMode, CrawlerRunConfig, LXMLWebScrapingStrategy, - CacheMode, ) -async def main(): +@pytest.mark.asyncio +async def test_scraping_strategy(): config = CrawlerRunConfig( cache_mode=CacheMode.BYPASS, scraping_strategy=LXMLWebScrapingStrategy(), # Faster alternative to default BeautifulSoup ) async with AsyncWebCrawler() as crawler: result = await crawler.arun(url="https://example.com", config=config) - print(f"Success: {result.success}") - 
print(f"Markdown length: {len(result.markdown.raw_markdown)}") + assert result.success + assert result.markdown + assert result.markdown.raw_markdown if __name__ == "__main__": - asyncio.run(main()) + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]])) diff --git a/tests/test_web_crawler.py b/tests/test_web_crawler.py index b84531924..f684f0acb 100644 --- a/tests/test_web_crawler.py +++ b/tests/test_web_crawler.py @@ -1,17 +1,23 @@ -import unittest, os -from crawl4ai import LLMConfig -from crawl4ai.web_crawler import WebCrawler +import os +import sys +import unittest + +import pytest + +from crawl4ai import CacheMode +from crawl4ai.async_configs import LLMConfig from crawl4ai.chunking_strategy import ( - RegexChunking, FixedLengthWordChunking, + RegexChunking, SlidingWindowChunking, + TopicSegmentationChunking, ) from crawl4ai.extraction_strategy import ( CosineStrategy, LLMExtractionStrategy, - TopicExtractionStrategy, NoExtractionStrategy, ) +from crawl4ai.legacy.web_crawler import WebCrawler class TestWebCrawler(unittest.TestCase): @@ -28,13 +34,16 @@ def test_run_default_strategies(self): word_count_threshold=5, chunking_strategy=RegexChunking(), extraction_strategy=CosineStrategy(), - bypass_cache=True, + cache_mode=CacheMode.BYPASS, + warmup=False, ) self.assertTrue( result.success, "Failed to crawl and extract using default strategies" ) def test_run_different_strategies(self): + if not os.getenv("OPENAI_API_KEY"): + self.skipTest("Skipping env OPENAI_API_KEY not set") url = "https://www.nbcnews.com/business" # Test with FixedLengthWordChunking and LLMExtractionStrategy @@ -45,7 +54,8 @@ def test_run_different_strategies(self): extraction_strategy=LLMExtractionStrategy( llm_config=LLMConfig(provider="openai/gpt-3.5-turbo", api_token=os.getenv("OPENAI_API_KEY")) ), - bypass_cache=True, + cache_mode=CacheMode.BYPASS, + warmup=False ) self.assertTrue( result.success, @@ -56,9 +66,11 @@ def test_run_different_strategies(self): result = self.crawler.run( url=url, word_count_threshold=5, - chunking_strategy=SlidingWindowChunking(window_size=100, step=50), - extraction_strategy=TopicExtractionStrategy(num_keywords=5), - bypass_cache=True, + chunking_strategy=TopicSegmentationChunking( + window_size=100, step=50, num_keywords=5 + ), + cache_mode=CacheMode.BYPASS, + warmup=False, ) self.assertTrue( result.success, @@ -66,37 +78,46 @@ def test_run_different_strategies(self): ) def test_invalid_url(self): - with self.assertRaises(Exception) as context: - self.crawler.run(url="invalid_url", bypass_cache=True) - self.assertIn("Invalid URL", str(context.exception)) + result = self.crawler.run( + url="invalid_url", cache_mode=CacheMode.BYPASS, warmup=False + ) + self.assertFalse(result.success, "Extraction should fail with invalid URL") + msg = "" if not result.error_message else result.error_message + self.assertTrue("invalid argument" in msg) def test_unsupported_extraction_strategy(self): - with self.assertRaises(Exception) as context: - self.crawler.run( - url="https://www.nbcnews.com/business", - extraction_strategy="UnsupportedStrategy", - bypass_cache=True, - ) - self.assertIn("Unsupported extraction strategy", str(context.exception)) + result = self.crawler.run( + url="https://www.nbcnews.com/business", + extraction_strategy="UnsupportedStrategy", # pyright: ignore[reportArgumentType] + cache_mode=CacheMode.BYPASS, + ) + self.assertFalse( + result.success, "Extraction should fail with unsupported strategy" + ) + self.assertEqual("Unsupported extraction 
strategy", result.error_message) + @pytest.mark.skip("Skipping InvalidCSSSelectorError is no longer raised") def test_invalid_css_selector(self): - with self.assertRaises(ValueError) as context: - self.crawler.run( - url="https://www.nbcnews.com/business", - css_selector="invalid_selector", - bypass_cache=True, - ) - self.assertIn("Invalid CSS selector", str(context.exception)) + result = self.crawler.run( + url="https://www.nbcnews.com/business", + css_selector="invalid_selector", + cache_mode=CacheMode.BYPASS, + warmup=False + ) + self.assertFalse( + result.success, "Extraction should fail with invalid CSS selector" + ) + self.assertEqual("Invalid CSS selector", result.error_message) def test_crawl_with_cache_and_bypass_cache(self): url = "https://www.nbcnews.com/business" # First crawl with cache enabled - result = self.crawler.run(url=url, bypass_cache=False) + result = self.crawler.run(url=url, bypass_cache=False, warmup=False) self.assertTrue(result.success, "Failed to crawl and cache the result") - # Second crawl with bypass_cache=True - result = self.crawler.run(url=url, bypass_cache=True) + # Second crawl with cache_mode=CacheMode.BYPASS + result = self.crawler.run(url=url, cache_mode=CacheMode.BYPASS, warmup=False) self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data") def test_fetch_multiple_pages(self): @@ -108,7 +129,8 @@ def test_fetch_multiple_pages(self): word_count_threshold=5, chunking_strategy=RegexChunking(), extraction_strategy=CosineStrategy(), - bypass_cache=True, + cache_mode=CacheMode.BYPASS, + warmup=False ) results.append(result) @@ -124,7 +146,8 @@ def test_run_fixed_length_word_chunking_and_no_extraction(self): word_count_threshold=5, chunking_strategy=FixedLengthWordChunking(chunk_size=100), extraction_strategy=NoExtractionStrategy(), - bypass_cache=True, + cache_mode=CacheMode.BYPASS, + warmup=False, ) self.assertTrue( result.success, @@ -137,7 +160,8 @@ def test_run_sliding_window_and_no_extraction(self): word_count_threshold=5, chunking_strategy=SlidingWindowChunking(window_size=100, step=50), extraction_strategy=NoExtractionStrategy(), - bypass_cache=True, + cache_mode=CacheMode.BYPASS, + warmup=False ) self.assertTrue( result.success, @@ -146,4 +170,6 @@ def test_run_sliding_window_and_no_extraction(self): if __name__ == "__main__": - unittest.main() + import subprocess + + sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))