Short but important section header description.
Long paragraph with sufficient words to meet the minimum threshold. This paragraph continues with more text to ensure we have enough content for proper testing. We need to make sure this has enough words to pass our filters and be considered valid content for extraction purposes.
@@ -78,7 +77,9 @@ def test_user_query_override(self, basic_html):
# Access internal state to verify query usage
soup = BeautifulSoup(basic_html, "lxml")
- extracted_query = filter.extract_page_query(soup.find("head"))
+ head: Union[PageElement, Tag, None] = soup.find("head")
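+        # soup.find() may return None or a non-Tag node, so narrow the type before calling extract_page_query.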
+ assert isinstance(head, Tag)
+ extracted_query = filter.extract_page_query(soup, head)
assert extracted_query == user_query
assert "Test description" not in extracted_query
@@ -89,7 +90,7 @@ def test_header_extraction(self, wiki_html):
contents = filter.filter_content(wiki_html)
combined_content = " ".join(contents).lower()
- assert "section 1" in combined_content, "Should include section header"
+ assert "section one" in combined_content, "Should include section header"
assert "article title" in combined_content, "Should include main title"
def test_no_metadata_fallback(self, no_meta_html):
@@ -106,7 +107,6 @@ def test_empty_input(self):
"""Test handling of empty input"""
filter = BM25ContentFilter()
assert filter.filter_content("") == []
- assert filter.filter_content(None) == []
def test_malformed_html(self):
"""Test handling of malformed HTML"""
@@ -142,7 +142,7 @@ def test_large_content(self):
"""Test handling of large content blocks"""
large_html = f"""
-
{'Test content. ' * 1000}
+
{'Test content. ' * 1000}
"""
filter = BM25ContentFilter()
@@ -180,4 +180,6 @@ def test_performance(self, basic_html):
if __name__ == "__main__":
- pytest.main([__file__])
+ import subprocess
+
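+    # Re-invoke pytest on this file in a subprocess, forwarding any extra CLI arguments.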
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/async/test_content_filter_prune.py b/tests/async/test_content_filter_prune.py
index 1f75a9e1e..d443a5cd0 100644
--- a/tests/async/test_content_filter_prune.py
+++ b/tests/async/test_content_filter_prune.py
@@ -1,8 +1,6 @@
-import os, sys
-import pytest
+import sys
-parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.append(parent_dir)
+import pytest
from crawl4ai.content_filter_strategy import PruningContentFilter
@@ -86,11 +84,11 @@ def test_min_word_threshold(self, mixed_content_html):
def test_threshold_types(self, basic_html):
"""Test fixed vs dynamic thresholds"""
- fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=0.48)
- dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=0.45)
+ fixed_filter = PruningContentFilter(threshold_type="fixed", threshold=1.1)
+ dynamic_filter = PruningContentFilter(threshold_type="dynamic", threshold=1.1)
- fixed_contents = fixed_filter.filter_content(basic_html)
- dynamic_contents = dynamic_filter.filter_content(basic_html)
+ fixed_contents = "".join(fixed_filter.filter_content(basic_html))
+ dynamic_contents = "".join(dynamic_filter.filter_content(basic_html))
assert len(fixed_contents) != len(
dynamic_contents
@@ -120,7 +118,6 @@ def test_empty_input(self):
"""Test handling of empty input"""
filter = PruningContentFilter()
assert filter.filter_content("") == []
- assert filter.filter_content(None) == []
def test_malformed_html(self):
"""Test handling of malformed HTML"""
@@ -167,4 +164,6 @@ def test_consistent_output(self, basic_html):
if __name__ == "__main__":
- pytest.main([__file__])
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/async/test_content_scraper_strategy.py b/tests/async/test_content_scraper_strategy.py
index e6caf2405..7db09ae9a 100644
--- a/tests/async/test_content_scraper_strategy.py
+++ b/tests/async/test_content_scraper_strategy.py
@@ -1,219 +1,232 @@
-import os
+import csv
import sys
import time
-import csv
+from dataclasses import asdict, dataclass
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, List
+
+import pytest
+from _pytest.mark import ParameterSet
from tabulate import tabulate
-from dataclasses import dataclass
-from typing import List
-parent_dir = os.path.dirname(
- os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+from crawl4ai.content_scraping_strategy import (
+ ContentScrapingStrategy,
+ WebScrapingStrategy,
)
-sys.path.append(parent_dir)
-__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
-
-from crawl4ai.content_scraping_strategy import WebScrapingStrategy
from crawl4ai.content_scraping_strategy import (
WebScrapingStrategy as WebScrapingStrategyCurrent,
)
-# from crawl4ai.content_scrapping_strategy_current import WebScrapingStrategy as WebScrapingStrategyCurrent
+from crawl4ai.models import ScrapingResult
@dataclass
-class TestResult:
+class Result:
name: str
+ strategy: str
success: bool
images: int
internal_links: int
external_links: int
- markdown_length: int
+ cleaned_html_length: int
execution_time: float
-class StrategyTester:
- def __init__(self):
- self.new_scraper = WebScrapingStrategy()
- self.current_scraper = WebScrapingStrategyCurrent()
- with open(__location__ + "/sample_wikipedia.html", "r", encoding="utf-8") as f:
- self.WIKI_HTML = f.read()
- self.results = {"new": [], "current": []}
-
- def run_test(self, name: str, **kwargs) -> tuple[TestResult, TestResult]:
- results = []
- for scraper in [self.new_scraper, self.current_scraper]:
- start_time = time.time()
- result = scraper._get_content_of_website_optimized(
- url="https://en.wikipedia.org/wiki/Test", html=self.WIKI_HTML, **kwargs
- )
- execution_time = time.time() - start_time
-
- test_result = TestResult(
- name=name,
- success=result["success"],
- images=len(result["media"]["images"]),
- internal_links=len(result["links"]["internal"]),
- external_links=len(result["links"]["external"]),
- markdown_length=len(result["markdown"]),
- execution_time=execution_time,
- )
- results.append(test_result)
-
- return results[0], results[1] # new, current
-
- def run_all_tests(self):
- test_cases = [
- ("Basic Extraction", {}),
- ("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}),
- ("Word Threshold", {"word_count_threshold": 50}),
- ("CSS Selector", {"css_selector": "div.mw-parser-output > p"}),
- (
- "Link Exclusions",
- {
- "exclude_external_links": True,
- "exclude_social_media_links": True,
- "exclude_domains": ["facebook.com", "twitter.com"],
- },
- ),
- (
- "Media Handling",
- {
- "exclude_external_images": True,
- "image_description_min_word_threshold": 20,
- },
- ),
- ("Text Only", {"only_text": True, "remove_forms": True}),
- ("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}),
- (
- "HTML2Text Options",
- {
- "html2text": {
- "skip_internal_links": True,
- "single_line_break": True,
- "mark_code": True,
- "preserve_tags": ["pre", "code"],
- }
- },
- ),
+@lru_cache
+@pytest.fixture
+def wiki_html() -> str:
+ file_path: Path = Path(__file__).parent / "sample_wikipedia.html"
+ with file_path.open("r", encoding="utf-8") as f:
+ return f.read()
+
+
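+# Metrics collected by each parametrized run of test_strategy; consumed by
+# print_comparison_table() and write_results_to_csv().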
+results: List[Result] = []
+
+
+def print_comparison_table():
+ """Print comparison table of results."""
+ if not results:
+ return
+
+ table_data = []
+ headers = [
+ "Test Name",
+ "Strategy",
+ "Success",
+ "Images",
+ "Internal Links",
+ "External Links",
+ "Cleaned HTML Length",
+ "Time (s)",
+ ]
+
+ all_results: List[tuple[str, Result, Result]] = []
+ new_results = [result for result in results if result.strategy == "new"]
+ current_results = [result for result in results if result.strategy == "current"]
+ for new_result in new_results:
+ for current_result in current_results:
+ if new_result.name == current_result.name:
+ all_results.append((new_result.name, new_result, current_result))
+
+ for name, new_result, current_result in all_results:
+ # Check for differences
+ differences = []
+ if new_result.images != current_result.images:
+ differences.append("images")
+ if new_result.internal_links != current_result.internal_links:
+ differences.append("internal_links")
+ if new_result.external_links != current_result.external_links:
+ differences.append("external_links")
+ if new_result.cleaned_html_length != current_result.cleaned_html_length:
+ differences.append("cleaned_html")
+
+ # Add row for new strategy
+ new_row = [
+ name,
+ "New",
+ new_result.success,
+ new_result.images,
+ new_result.internal_links,
+ new_result.external_links,
+ new_result.cleaned_html_length,
+ f"{new_result.execution_time:.3f}",
+ ]
+ table_data.append(new_row)
+
+ # Add row for current strategy
+ current_row = [
+ "",
+ "Current",
+ current_result.success,
+ current_result.images,
+ current_result.internal_links,
+ current_result.external_links,
+ current_result.cleaned_html_length,
+ f"{current_result.execution_time:.3f}",
]
+ table_data.append(current_row)
- all_results = []
- for name, kwargs in test_cases:
- try:
- new_result, current_result = self.run_test(name, **kwargs)
- all_results.append((name, new_result, current_result))
- except Exception as e:
- print(f"Error in {name}: {str(e)}")
-
- self.save_results_to_csv(all_results)
- self.print_comparison_table(all_results)
-
- def save_results_to_csv(self, all_results: List[tuple]):
- csv_file = os.path.join(__location__, "strategy_comparison_results.csv")
- with open(csv_file, "w", newline="") as f:
- writer = csv.writer(f)
- writer.writerow(
- [
- "Test Name",
- "Strategy",
- "Success",
- "Images",
- "Internal Links",
- "External Links",
- "Markdown Length",
- "Execution Time",
- ]
+ # Add difference summary if any
+ if differences:
+ table_data.append(
+ ["", "⚠️ Differences", ", ".join(differences), "", "", "", "", ""]
)
- for name, new_result, current_result in all_results:
- writer.writerow(
- [
- name,
- "New",
- new_result.success,
- new_result.images,
- new_result.internal_links,
- new_result.external_links,
- new_result.markdown_length,
- f"{new_result.execution_time:.3f}",
- ]
- )
- writer.writerow(
- [
- name,
- "Current",
- current_result.success,
- current_result.images,
- current_result.internal_links,
- current_result.external_links,
- current_result.markdown_length,
- f"{current_result.execution_time:.3f}",
- ]
- )
-
- def print_comparison_table(self, all_results: List[tuple]):
- table_data = []
- headers = [
- "Test Name",
- "Strategy",
- "Success",
- "Images",
- "Internal Links",
- "External Links",
- "Markdown Length",
- "Time (s)",
- ]
-
- for name, new_result, current_result in all_results:
- # Check for differences
- differences = []
- if new_result.images != current_result.images:
- differences.append("images")
- if new_result.internal_links != current_result.internal_links:
- differences.append("internal_links")
- if new_result.external_links != current_result.external_links:
- differences.append("external_links")
- if new_result.markdown_length != current_result.markdown_length:
- differences.append("markdown")
-
- # Add row for new strategy
- new_row = [
- name,
- "New",
- new_result.success,
- new_result.images,
- new_result.internal_links,
- new_result.external_links,
- new_result.markdown_length,
- f"{new_result.execution_time:.3f}",
- ]
- table_data.append(new_row)
-
- # Add row for current strategy
- current_row = [
- "",
- "Current",
- current_result.success,
- current_result.images,
- current_result.internal_links,
- current_result.external_links,
- current_result.markdown_length,
- f"{current_result.execution_time:.3f}",
+ # Add empty row for better readability
+ table_data.append([""] * len(headers))
+
+ print("\nStrategy Comparison Results:")
+ print(tabulate(table_data, headers=headers, tablefmt="grid"))
+
+
+def write_results_to_csv():
+    """Write collected results to a CSV file."""
+ if not results:
+ return
+ csv_file: Path = Path(__file__).parent / "output/strategy_comparison_results.csv"
+ with csv_file.open("w", newline="") as f:
+ writer = csv.writer(f)
+ writer.writerow(
+ [
+ "Test Name",
+ "Strategy",
+ "Success",
+ "Images",
+ "Internal Links",
+ "External Links",
+ "Cleaned HTML Length",
+ "Execution Time",
]
- table_data.append(current_row)
-
- # Add difference summary if any
- if differences:
- table_data.append(
- ["", "⚠️ Differences", ", ".join(differences), "", "", "", "", ""]
+ )
+
+ for result in results:
+ writer.writerow(asdict(result))
+
+
+def scrapper_params() -> List[ParameterSet]:
+ test_cases = [
+ ("Basic Extraction", {}),
+ ("Exclude Tags", {"excluded_tags": ["table", "div.infobox", "div.navbox"]}),
+ ("Word Threshold", {"word_count_threshold": 50}),
+ ("CSS Selector", {"css_selector": "div.mw-parser-output > p"}),
+ (
+ "Link Exclusions",
+ {
+ "exclude_external_links": True,
+ "exclude_social_media_links": True,
+ "exclude_domains": ["facebook.com", "twitter.com"],
+ },
+ ),
+ (
+ "Media Handling",
+ {
+ "exclude_external_images": True,
+ "image_description_min_word_threshold": 20,
+ },
+ ),
+ ("Text Only", {"only_text": True, "remove_forms": True}),
+ ("HTML Cleaning", {"clean_html": True, "keep_data_attributes": True}),
+ (
+ "HTML2Text Options",
+ {
+ "html2text": {
+ "skip_internal_links": True,
+ "single_line_break": True,
+ "mark_code": True,
+ "preserve_tags": ["pre", "code"],
+ }
+ },
+ ),
+ ]
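+    # Build one pytest param per (test case, strategy) pair so every scenario runs
+    # against both the new and the current WebScrapingStrategy implementations.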
+ params: List[ParameterSet] = []
+ for strategy_name, strategy in [
+ ("new", WebScrapingStrategy()),
+ ("current", WebScrapingStrategyCurrent()),
+ ]:
+ for name, kwargs in test_cases:
+ params.append(
+ pytest.param(
+ name,
+ strategy_name,
+ strategy,
+ kwargs,
+ id=f"{name} - {strategy_name}",
)
+ )
- # Add empty row for better readability
- table_data.append([""] * len(headers))
-
- print("\nStrategy Comparison Results:")
- print(tabulate(table_data, headers=headers, tablefmt="grid"))
+ return params
+
+
+@pytest.mark.parametrize("name,strategy_name,strategy,kwargs", scrapper_params())
+def test_strategy(
+ wiki_html: str,
+ name: str,
+ strategy_name: str,
+ strategy: ContentScrapingStrategy,
+ kwargs: dict[str, Any],
+):
+ start_time = time.time()
+ result: ScrapingResult = strategy.scrap(
+ url="https://en.wikipedia.org/wiki/Test", html=wiki_html, **kwargs
+ )
+ assert result.success
+ execution_time = time.time() - start_time
+
+ results.append(
+ Result(
+ name=name,
+ strategy=strategy_name,
+ success=result.success,
+ images=len(result.media.images),
+ internal_links=len(result.links.internal),
+ external_links=len(result.links.external),
+ cleaned_html_length=len(result.cleaned_html),
+ execution_time=execution_time,
+ )
+ )
if __name__ == "__main__":
- tester = StrategyTester()
- tester.run_all_tests()
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/async/test_crawler_strategy.py b/tests/async/test_crawler_strategy.py
index 337b5aaa8..8501bcd56 100644
--- a/tests/async/test_crawler_strategy.py
+++ b/tests/async/test_crawler_strategy.py
@@ -1,11 +1,8 @@
-import os
import sys
-import pytest
-# Add the parent directory to the Python path
-parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.append(parent_dir)
+import pytest
+from crawl4ai import CacheMode
from crawl4ai.async_webcrawler import AsyncWebCrawler
@@ -15,7 +12,7 @@ async def test_custom_user_agent():
custom_user_agent = "MyCustomUserAgent/1.0"
crawler.crawler_strategy.update_user_agent(custom_user_agent)
url = "https://httpbin.org/user-agent"
- result = await crawler.arun(url=url, bypass_cache=True)
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
assert result.success
assert custom_user_agent in result.html
@@ -26,7 +23,7 @@ async def test_custom_headers():
custom_headers = {"X-Test-Header": "TestValue"}
crawler.crawler_strategy.set_custom_headers(custom_headers)
url = "https://httpbin.org/headers"
- result = await crawler.arun(url=url, bypass_cache=True)
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
assert result.success
assert "X-Test-Header" in result.html
assert "TestValue" in result.html
@@ -37,7 +34,7 @@ async def test_javascript_execution():
async with AsyncWebCrawler(verbose=True) as crawler:
js_code = "document.body.innerHTML = '
Modified by JS
';"
url = "https://www.example.com"
- result = await crawler.arun(url=url, bypass_cache=True, js_code=js_code)
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, js_code=js_code)
assert result.success
assert "
Modified by JS
" in result.html
@@ -46,13 +43,13 @@ async def test_javascript_execution():
async def test_hook_execution():
async with AsyncWebCrawler(verbose=True) as crawler:
- async def test_hook(page):
+ async def test_hook(page, **kwargs):
await page.evaluate("document.body.style.backgroundColor = 'red';")
return page
crawler.crawler_strategy.set_hook("after_goto", test_hook)
url = "https://www.example.com"
- result = await crawler.arun(url=url, bypass_cache=True)
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
assert result.success
assert "background-color: red" in result.html
@@ -61,7 +58,7 @@ async def test_hook(page):
async def test_screenshot():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.example.com"
- result = await crawler.arun(url=url, bypass_cache=True, screenshot=True)
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, screenshot=True)
assert result.success
assert result.screenshot
assert isinstance(result.screenshot, str)
@@ -70,4 +67,6 @@ async def test_screenshot():
# Entry point for debugging
if __name__ == "__main__":
- pytest.main([__file__, "-v"])
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/async/test_database_operations.py b/tests/async/test_database_operations.py
index db0d328ed..3632579f1 100644
--- a/tests/async/test_database_operations.py
+++ b/tests/async/test_database_operations.py
@@ -1,11 +1,8 @@
-import os
import sys
-import pytest
-# Add the parent directory to the Python path
-parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.append(parent_dir)
+import pytest
+from crawl4ai import CacheMode
from crawl4ai.async_webcrawler import AsyncWebCrawler
@@ -14,7 +11,7 @@ async def test_cache_url():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.example.com"
# First run to cache the URL
- result1 = await crawler.arun(url=url, bypass_cache=True)
+ result1 = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
assert result1.success
# Second run to retrieve from cache
@@ -28,11 +25,11 @@ async def test_bypass_cache():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.python.org"
# First run to cache the URL
- result1 = await crawler.arun(url=url, bypass_cache=True)
+ result1 = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
assert result1.success
# Second run bypassing cache
- result2 = await crawler.arun(url=url, bypass_cache=True)
+ result2 = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
assert result2.success
assert (
result2.html != result1.html
@@ -42,10 +39,11 @@ async def test_bypass_cache():
@pytest.mark.asyncio
async def test_cache_size():
async with AsyncWebCrawler(verbose=True) as crawler:
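+        # Clear any pre-existing cache entries so the size comparison starts from a known baseline.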
+ await crawler.aclear_cache()
initial_size = await crawler.aget_cache_size()
url = "https://www.nbcnews.com/business"
- await crawler.arun(url=url, bypass_cache=True)
+ await crawler.arun(url=url, cache_mode=CacheMode.ENABLED)
new_size = await crawler.aget_cache_size()
assert new_size == initial_size + 1
@@ -55,7 +53,7 @@ async def test_cache_size():
async def test_clear_cache():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.example.org"
- await crawler.arun(url=url, bypass_cache=True)
+ await crawler.arun(url=url, cache_mode=CacheMode.ENABLED)
initial_size = await crawler.aget_cache_size()
assert initial_size > 0
@@ -69,7 +67,8 @@ async def test_clear_cache():
async def test_flush_cache():
async with AsyncWebCrawler(verbose=True) as crawler:
url = "https://www.example.net"
- await crawler.arun(url=url, bypass_cache=True)
+ result = await crawler.arun(url=url, cache_mode=CacheMode.ENABLED)
+ assert result and result.success
initial_size = await crawler.aget_cache_size()
assert initial_size > 0
@@ -87,4 +86,6 @@ async def test_flush_cache():
# Entry point for debugging
if __name__ == "__main__":
- pytest.main([__file__, "-v"])
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/async/test_dispatchers.py b/tests/async/test_dispatchers.py
index 99cf4a989..7caed748c 100644
--- a/tests/async/test_dispatchers.py
+++ b/tests/async/test_dispatchers.py
@@ -1,15 +1,18 @@
-import pytest
+import sys
import time
+
+from httpx import codes
+import pytest
+
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
+ CacheMode,
+ CrawlerMonitor,
CrawlerRunConfig,
MemoryAdaptiveDispatcher,
- SemaphoreDispatcher,
RateLimiter,
- CrawlerMonitor,
- DisplayMode,
- CacheMode,
+ SemaphoreDispatcher,
)
@@ -37,7 +40,7 @@ class TestDispatchStrategies:
async def test_memory_adaptive_basic(self, browser_config, run_config, test_urls):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(
- memory_threshold_percent=70.0, max_session_permit=2, check_interval=0.1
+ memory_threshold_percent=80.0, max_session_permit=2, check_interval=0.1
)
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
@@ -50,7 +53,7 @@ async def test_memory_adaptive_with_rate_limit(
):
async with AsyncWebCrawler(config=browser_config) as crawler:
dispatcher = MemoryAdaptiveDispatcher(
- memory_threshold_percent=70.0,
+ memory_threshold_percent=80.0,
max_session_permit=2,
check_interval=0.1,
rate_limiter=RateLimiter(
@@ -88,6 +91,7 @@ async def test_semaphore_with_rate_limit(
assert len(results) == len(test_urls)
assert all(r.success for r in results)
+ @pytest.mark.skip(reason="memory_wait_timeout is not a valid MemoryAdaptiveDispatcher parameter")
async def test_memory_adaptive_memory_error(
self, browser_config, run_config, test_urls
):
@@ -140,7 +144,7 @@ async def test_rate_limit_backoff(self, browser_config, run_config):
base_delay=(0.1, 0.2),
max_delay=1.0,
max_retries=2,
- rate_limit_codes=[200], # Force rate limiting for testing
+ rate_limit_codes=[codes.OK], # Force rate limiting for testing
),
)
start_time = time.time()
@@ -151,11 +155,10 @@ async def test_rate_limit_backoff(self, browser_config, run_config):
assert len(results) == len(urls)
assert duration > 1.0 # Ensure rate limiting caused delays
+ @pytest.mark.skip(reason="max_visible_rows is not a valid CrawlerMonitor parameter")
async def test_monitor_integration(self, browser_config, run_config, test_urls):
async with AsyncWebCrawler(config=browser_config) as crawler:
- monitor = CrawlerMonitor(
- max_visible_rows=5, display_mode=DisplayMode.DETAILED
- )
+ monitor = CrawlerMonitor(urls_total=5)
dispatcher = MemoryAdaptiveDispatcher(max_session_permit=2, monitor=monitor)
results = await crawler.arun_many(
test_urls, config=run_config, dispatcher=dispatcher
@@ -163,8 +166,10 @@ async def test_monitor_integration(self, browser_config, run_config, test_urls):
assert len(results) == len(test_urls)
# Check monitor stats
assert len(monitor.stats) == len(test_urls)
- assert all(stat.end_time is not None for stat in monitor.stats.values())
+ assert all(stat["end_time"] is not None for stat in monitor.stats.values())
if __name__ == "__main__":
- pytest.main([__file__, "-v", "--asyncio-mode=auto"])
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/async/test_edge_cases.py b/tests/async/test_edge_cases.py
index d3adb53c5..8a1cc5274 100644
--- a/tests/async/test_edge_cases.py
+++ b/tests/async/test_edge_cases.py
@@ -1,63 +1,67 @@
-import os
+import asyncio
import re
import sys
+
import pytest
from bs4 import BeautifulSoup
-import asyncio
-
-# Add the parent directory to the Python path
-parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-sys.path.append(parent_dir)
+from playwright.async_api import Page, BrowserContext
+from crawl4ai import CacheMode
from crawl4ai.async_webcrawler import AsyncWebCrawler
-# @pytest.mark.asyncio
-# async def test_large_content_page():
-# async with AsyncWebCrawler(verbose=True) as crawler:
-# url = "https://en.wikipedia.org/wiki/List_of_largest_known_stars" # A page with a large table
-# result = await crawler.arun(url=url, bypass_cache=True)
-# assert result.success
-# assert len(result.html) > 1000000 # Expecting more than 1MB of content
-
-# @pytest.mark.asyncio
-# async def test_minimal_content_page():
-# async with AsyncWebCrawler(verbose=True) as crawler:
-# url = "https://example.com" # A very simple page
-# result = await crawler.arun(url=url, bypass_cache=True)
-# assert result.success
-# assert len(result.html) < 10000 # Expecting less than 10KB of content
-
-# @pytest.mark.asyncio
-# async def test_single_page_application():
-# async with AsyncWebCrawler(verbose=True) as crawler:
-# url = "https://reactjs.org/" # React's website is a SPA
-# result = await crawler.arun(url=url, bypass_cache=True)
-# assert result.success
-# assert "react" in result.html.lower()
-
-# @pytest.mark.asyncio
-# async def test_page_with_infinite_scroll():
-# async with AsyncWebCrawler(verbose=True) as crawler:
-# url = "https://news.ycombinator.com/" # Hacker News has infinite scroll
-# result = await crawler.arun(url=url, bypass_cache=True)
-# assert result.success
-# assert "hacker news" in result.html.lower()
-
-# @pytest.mark.asyncio
-# async def test_page_with_heavy_javascript():
-# async with AsyncWebCrawler(verbose=True) as crawler:
-# url = "https://www.airbnb.com/" # Airbnb uses a lot of JavaScript
-# result = await crawler.arun(url=url, bypass_cache=True)
-# assert result.success
-# assert "airbnb" in result.html.lower()
-
-# @pytest.mark.asyncio
-# async def test_page_with_mixed_content():
-# async with AsyncWebCrawler(verbose=True) as crawler:
-# url = "https://github.com/" # GitHub has a mix of static and dynamic content
-# result = await crawler.arun(url=url, bypass_cache=True)
-# assert result.success
-# assert "github" in result.html.lower()
+
+@pytest.mark.asyncio
+async def test_large_content_page():
+ async with AsyncWebCrawler(verbose=True) as crawler:
+ url = "https://en.wikipedia.org/wiki/List_of_largest_known_stars" # A page with a large table
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
+ assert result.success
+ assert len(result.html) > 1000000 # Expecting more than 1MB of content
+
+
+@pytest.mark.asyncio
+async def test_minimal_content_page():
+ async with AsyncWebCrawler(verbose=True) as crawler:
+ url = "https://example.com" # A very simple page
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
+ assert result.success
+ assert len(result.html) < 10000 # Expecting less than 10KB of content
+
+
+@pytest.mark.asyncio
+async def test_single_page_application():
+ async with AsyncWebCrawler(verbose=True) as crawler:
+ url = "https://reactjs.org/" # React's website is a SPA
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
+ assert result.success
+ assert "react" in result.html.lower()
+
+
+@pytest.mark.asyncio
+async def test_page_with_infinite_scroll():
+ async with AsyncWebCrawler(verbose=True) as crawler:
+ url = "https://news.ycombinator.com/" # Hacker News has infinite scroll
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
+ assert result.success
+ assert "hacker news" in result.html.lower()
+
+
+@pytest.mark.asyncio
+async def test_page_with_heavy_javascript():
+ async with AsyncWebCrawler(verbose=True) as crawler:
+ url = "https://www.airbnb.com/" # Airbnb uses a lot of JavaScript
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
+ assert result.success
+ assert "airbnb" in result.html.lower()
+
+
+@pytest.mark.asyncio
+async def test_page_with_mixed_content():
+ async with AsyncWebCrawler(verbose=True) as crawler:
+ url = "https://github.com/" # GitHub has a mix of static and dynamic content
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
+ assert result.success
+ assert "github" in result.html.lower()
# Add this test to your existing test file
@@ -65,13 +69,13 @@
async def test_typescript_commits_multi_page():
first_commit = ""
- async def on_execution_started(page):
+ async def on_execution_started(page: Page, context: BrowserContext, **kwargs):
nonlocal first_commit
try:
- # Check if the page firct commit h4 text is different from the first commit (use document.querySelector('li.Box-sc-g0xbh4-0 h4'))
+            # Check if the page's first commit h4 text differs from the previously seen commit (use document.querySelector('div.Box-sc-g0xbh4-0 h4'))
while True:
- await page.wait_for_selector("li.Box-sc-g0xbh4-0 h4")
- commit = await page.query_selector("li.Box-sc-g0xbh4-0 h4")
+ await page.wait_for_selector("div.Box-sc-g0xbh4-0 h4")
+ commit = await page.query_selector("div.Box-sc-g0xbh4-0 h4")
commit = await commit.evaluate("(element) => element.textContent")
commit = re.sub(r"\s+", "", commit)
if commit and commit != first_commit:
@@ -97,16 +101,14 @@ async def on_execution_started(page):
result = await crawler.arun(
url=url, # Only use URL for the first page
session_id=session_id,
- css_selector="li.Box-sc-g0xbh4-0",
- js=js_next_page
- if page > 0
- else None, # Don't click 'next' on the first page
- bypass_cache=True,
+ css_selector="div.Box-sc-g0xbh4-0",
+ js_code=js_next_page if page > 0 else None, # Don't click 'next' on the first page
+ cache_mode=CacheMode.BYPASS,
js_only=page > 0, # Use js_only for subsequent pages
)
assert result.success, f"Failed to crawl page {page + 1}"
-
+ assert result.cleaned_html, f"No cleaned HTML found for page {page + 1}"
# Parse the HTML and extract commits
soup = BeautifulSoup(result.cleaned_html, "html.parser")
commits = soup.select("li")
@@ -121,13 +123,13 @@ async def on_execution_started(page):
await crawler.crawler_strategy.kill_session(session_id)
# Assertions
- assert (
- len(all_commits) >= 90
- ), f"Expected at least 90 commits, but got {len(all_commits)}"
+ assert len(all_commits) >= 90, f"Expected at least 90 commits, but got {len(all_commits)}"
print(f"Successfully crawled {len(all_commits)} commits across 3 pages")
# Entry point for debugging
if __name__ == "__main__":
- pytest.main([__file__, "-v"])
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/async/test_error_handling.py b/tests/async/test_error_handling.py
index ae4af6c84..ca06722da 100644
--- a/tests/async/test_error_handling.py
+++ b/tests/async/test_error_handling.py
@@ -1,78 +1,83 @@
-# import os
-# import sys
-# import pytest
-# import asyncio
-
-# # Add the parent directory to the Python path
-# parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-# sys.path.append(parent_dir)
-
-# from crawl4ai.async_webcrawler import AsyncWebCrawler
-# from crawl4ai.utils import InvalidCSSSelectorError
-
-# class AsyncCrawlerWrapper:
-# def __init__(self):
-# self.crawler = None
-
-# async def setup(self):
-# self.crawler = AsyncWebCrawler(verbose=True)
-# await self.crawler.awarmup()
-
-# async def cleanup(self):
-# if self.crawler:
-# await self.crawler.aclear_cache()
-
-# @pytest.fixture(scope="module")
-# def crawler_wrapper():
-# wrapper = AsyncCrawlerWrapper()
-# asyncio.get_event_loop().run_until_complete(wrapper.setup())
-# yield wrapper
-# asyncio.get_event_loop().run_until_complete(wrapper.cleanup())
-
-# @pytest.mark.asyncio
-# async def test_network_error(crawler_wrapper):
-# url = "https://www.nonexistentwebsite123456789.com"
-# result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True)
-# assert not result.success
-# assert "Failed to crawl" in result.error_message
-
-# # @pytest.mark.asyncio
-# # async def test_timeout_error(crawler_wrapper):
-# # # Simulating a timeout by using a very short timeout value
-# # url = "https://www.nbcnews.com/business"
-# # result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True, timeout=0.001)
-# # assert not result.success
-# # assert "timeout" in result.error_message.lower()
-
-# # @pytest.mark.asyncio
-# # async def test_invalid_css_selector(crawler_wrapper):
-# # url = "https://www.nbcnews.com/business"
-# # with pytest.raises(InvalidCSSSelectorError):
-# # await crawler_wrapper.crawler.arun(url=url, bypass_cache=True, css_selector="invalid>>selector")
-
-# # @pytest.mark.asyncio
-# # async def test_js_execution_error(crawler_wrapper):
-# # url = "https://www.nbcnews.com/business"
-# # invalid_js = "This is not valid JavaScript code;"
-# # result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True, js=invalid_js)
-# # assert not result.success
-# # assert "JavaScript" in result.error_message
-
-# # @pytest.mark.asyncio
-# # async def test_empty_page(crawler_wrapper):
-# # # Use a URL that typically returns an empty page
-# # url = "http://example.com/empty"
-# # result = await crawler_wrapper.crawler.arun(url=url, bypass_cache=True)
-# # assert result.success # The crawl itself should succeed
-# # assert not result.markdown.strip() # The markdown content should be empty or just whitespace
-
-# # @pytest.mark.asyncio
-# # async def test_rate_limiting(crawler_wrapper):
-# # # Simulate rate limiting by making multiple rapid requests
-# # url = "https://www.nbcnews.com/business"
-# # results = await asyncio.gather(*[crawler_wrapper.crawler.arun(url=url, bypass_cache=True) for _ in range(10)])
-# # assert any(not result.success and "rate limit" in result.error_message.lower() for result in results)
-
-# # Entry point for debugging
-# if __name__ == "__main__":
-# pytest.main([__file__, "-v"])
+import asyncio
+import sys
+
+import pytest
+import pytest_asyncio
+
+from crawl4ai import CacheMode
+from crawl4ai.async_webcrawler import AsyncWebCrawler
+from crawl4ai.utils import InvalidCSSSelectorError
+
+
+@pytest_asyncio.fixture
+async def crawler():
+ async with AsyncWebCrawler(verbose=True, warmup=False) as crawler:
+ yield crawler
+
+
+@pytest.mark.asyncio
+async def test_network_error(crawler: AsyncWebCrawler):
+ url = "https://www.nonexistentwebsite123456789.com"
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
+ assert not result.success
+ assert result.error_message
+ assert "Failed on navigating ACS-GOTO" in result.error_message
+
+
+@pytest.mark.asyncio
+async def test_timeout_error(crawler: AsyncWebCrawler):
+ # Simulating a timeout by using a very short timeout value
+ url = "https://www.nbcnews.com/business"
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, page_timeout=0.001)
+ assert not result.success
+ assert result.error_message
+ assert "timeout" in result.error_message.lower()
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Invalid CSS selector not raised any more")
+async def test_invalid_css_selector(crawler: AsyncWebCrawler):
+ url = "https://www.nbcnews.com/business"
+ with pytest.raises(InvalidCSSSelectorError):
+ await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, css_selector="invalid>>selector")
+
+
+@pytest.mark.asyncio
+async def test_js_execution_error(crawler: AsyncWebCrawler):
+ url = "https://www.nbcnews.com/business"
+ invalid_js = "This is not valid JavaScript code;"
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS, js_code=invalid_js)
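+    # The crawl itself succeeds; the invalid script's failure is reported per-script in js_execution_result["results"].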
+ assert result.success
+ assert result.js_execution_result
+ assert result.js_execution_result["success"]
+ results: list[dict] = result.js_execution_result["results"]
+ assert results
+ assert not results[0]["success"]
+ assert "SyntaxError" in results[0]["error"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("The page is not empty any more")
+async def test_empty_page(crawler: AsyncWebCrawler):
+ # Use a URL that typically returns an empty page
+ url = "http://example.com/empty"
+ result = await crawler.arun(url=url, cache_mode=CacheMode.BYPASS)
+ assert result.success # The crawl itself should succeed
+ assert result.markdown is not None
+ assert not result.markdown.strip() # The markdown content should be empty or just whitespace
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Rate limiting doesn't trigger")
+async def test_rate_limiting(crawler: AsyncWebCrawler):
+ # Simulate rate limiting by making multiple rapid requests
+ url = "https://www.nbcnews.com/business"
+ results = await asyncio.gather(*[crawler.arun(url=url, cache_mode=CacheMode.BYPASS) for _ in range(10)])
+ assert any(not result.success and result.error_message and "rate limit" in result.error_message.lower() for result in results)
+
+
+# Entry point for debugging
+if __name__ == "__main__":
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/async/test_evaluation_scraping_methods_performance.configs.py b/tests/async/test_evaluation_scraping_methods_performance_configs.py
similarity index 52%
rename from tests/async/test_evaluation_scraping_methods_performance.configs.py
rename to tests/async/test_evaluation_scraping_methods_performance_configs.py
index 797cf681c..e714a52f1 100644
--- a/tests/async/test_evaluation_scraping_methods_performance.configs.py
+++ b/tests/async/test_evaluation_scraping_methods_performance_configs.py
@@ -1,13 +1,20 @@
-import json
+import difflib
+import sys
import time
+from typing import Any, List, Tuple
+
+import pytest
+from _pytest.mark import ParameterSet
from bs4 import BeautifulSoup
+from bs4.element import Tag
+from lxml import etree
+from lxml import html as lhtml
+
from crawl4ai.content_scraping_strategy import (
- WebScrapingStrategy,
LXMLWebScrapingStrategy,
+ WebScrapingStrategy,
)
-from typing import Dict, List, Tuple
-import difflib
-from lxml import html as lhtml, etree
+from crawl4ai.models import Links, Media, ScrapingResult
def normalize_dom(element):
@@ -197,7 +204,7 @@ def generate_complicated_html():
Complicated Test Page
-
+
+
+
+
+
+
Example Domain
+
This domain is for use in illustrative examples in documents. You may use this
+ domain in literature without prior coordination or asking for permission.
+
More information...
+
+
+""")
+
+ app = web.Application()
+ app.router.add_get("/{route}", handler)
+ return await aiohttp_server(app)
+
+@pytest_asyncio.fixture
+async def browsers_manager(tmp_path: Path) -> AsyncGenerator[BrowsersManager, None]:
+ manager: BrowsersManager = BrowsersManager()
+ await manager.run(tmp_path)
+ yield manager
+ await manager.close()
+
+
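+# Fixture: a BrowserManager in builtin mode, started before each test and closed afterwards.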
+@pytest_asyncio.fixture
+async def manager() -> AsyncGenerator[BrowserManager, None]:
+ browser_config: BrowserConfig = BrowserConfig(browser_mode="builtin", headless=True, verbose=True)
+ manager: BrowserManager = BrowserManager(browser_config=browser_config)
+ await manager.start()
+ yield manager
+ await manager.close()
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("browser_type", ["webkit", "firefox"])
+async def test_not_supported(tmp_path: Path, browser_type: str):
+ browser_config: BrowserConfig = BrowserConfig(
+ browser_mode="builtin",
+ browser_type=browser_type,
+ headless=True,
+ debugging_port=0,
+ user_data_dir=str(tmp_path),
+ )
+ logger: AsyncLogger = AsyncLogger()
+ manager: BrowserManager = BrowserManager(browser_config=browser_config, logger=logger)
+ assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}"
+ manager._strategy.shutting_down = True
+ with pytest.raises(Exception):
+ await manager.start()
- # Step 3: Start the manager to launch or connect to builtin browser
- print(f"\n{INFO}3. Starting the browser manager{RESET}")
+@pytest.mark.asyncio
+@pytest.mark.parametrize("browser_type", ["chromium"])
+async def test_ephemeral_port(tmp_path: Path, browser_type: str):
+ browser_config: BrowserConfig = BrowserConfig(
+ browser_mode="builtin",
+ browser_type=browser_type,
+ headless=True,
+ debugging_port=0,
+ user_data_dir=str(tmp_path),
+ )
+ logger: AsyncLogger = AsyncLogger()
+ manager: BrowserManager = BrowserManager(browser_config=browser_config, logger=logger)
+ assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}"
+ manager._strategy.shutting_down = True
try:
await manager.start()
- print(f"{SUCCESS}Browser manager started successfully{RESET}")
- except Exception as e:
- print(f"{ERROR}Failed to start browser manager: {str(e)}{RESET}")
- return None
-
- # Step 4: Get browser info from the strategy
- print(f"\n{INFO}4. Getting browser information{RESET}")
- browser_info = manager._strategy.get_builtin_browser_info()
- if browser_info:
- print(f"{SUCCESS}Browser info retrieved:{RESET}")
- for key, value in browser_info.items():
- if key != "config": # Skip the verbose config section
- print(f" {key}: {value}")
-
- cdp_url = browser_info.get("cdp_url")
- print(f"{SUCCESS}CDP URL: {cdp_url}{RESET}")
- else:
- print(f"{ERROR}Failed to get browser information{RESET}")
- cdp_url = None
-
- # Save manager for later tests
- return manager, cdp_url
+ assert manager._strategy.config.debugging_port != 0, "Ephemeral port not assigned"
+ finally:
+ await manager.close()
+@pytest.mark.asyncio
+async def test_builtin_browser_creation(manager: BrowserManager):
+ """Test creating a builtin browser using the BrowserManager with BuiltinBrowserStrategy"""
+ # Check if we have a BuiltinBrowserStrategy
+ assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}"
+
+ # Check we can get browser info from the strategy
+ strategy: BuiltinBrowserStrategy = manager._strategy
+ browser_info = strategy.get_browser_info()
+ assert browser_info, "Failed to get browser info"
+@pytest.mark.asyncio
async def test_page_operations(manager: BrowserManager):
"""Test page operations with the builtin browser"""
- print(
- f"\n{INFO}========== Testing Page Operations with Builtin Browser =========={RESET}"
- )
# Step 1: Get a single page
- print(f"\n{INFO}1. Getting a single page{RESET}")
- try:
- crawler_config = CrawlerRunConfig()
- page, context = await manager.get_page(crawler_config)
- print(f"{SUCCESS}Got page successfully{RESET}")
+ crawler_config = CrawlerRunConfig()
+ page, context = await manager.get_page(crawler_config)
- # Navigate to a test URL
- await page.goto("https://example.com")
- title = await page.title()
- print(f"{SUCCESS}Page title: {title}{RESET}")
+ # Navigate to a test URL
+ await page.goto("https://example.com")
+ title = await page.title()
- # Close the page
- await page.close()
- print(f"{SUCCESS}Page closed successfully{RESET}")
- except Exception as e:
- print(f"{ERROR}Page operation failed: {str(e)}{RESET}")
- return False
+ # Close the page
+ await page.close()
# Step 2: Get multiple pages
- print(f"\n{INFO}2. Getting multiple pages with get_pages(){RESET}")
- try:
- # Request 3 pages
- crawler_config = CrawlerRunConfig()
- pages = await manager.get_pages(crawler_config, count=3)
- print(f"{SUCCESS}Got {len(pages)} pages{RESET}")
-
- # Test each page
- for i, (page, context) in enumerate(pages):
- await page.goto(f"https://example.com?test={i}")
- title = await page.title()
- print(f"{SUCCESS}Page {i + 1} title: {title}{RESET}")
- await page.close()
-
- print(f"{SUCCESS}All pages tested and closed successfully{RESET}")
- except Exception as e:
- print(f"{ERROR}Multiple page operation failed: {str(e)}{RESET}")
- return False
- return True
+ # Request 3 pages
+ crawler_config = CrawlerRunConfig()
+ pages = await manager.get_pages(crawler_config, count=3)
+ # Test each page
+ for i, (page, context) in enumerate(pages):
+ response = await page.goto(f"https://example.com?test={i}")
+ assert response, f"Failed to load page {i + 1}"
+ title: str = await page.title()
+ assert title == "Example Domain", f"Expected title 'Example Domain', got '{title}'"
+ await page.close()
+@pytest.mark.asyncio
async def test_browser_status_management(manager: BrowserManager):
"""Test browser status and management operations"""
- print(f"\n{INFO}========== Testing Browser Status and Management =========={RESET}")
-
- # Step 1: Get browser status
- print(f"\n{INFO}1. Getting browser status{RESET}")
- try:
- status = await manager._strategy.get_builtin_browser_status()
- print(f"{SUCCESS}Browser status:{RESET}")
- print(f" Running: {status['running']}")
- print(f" CDP URL: {status['cdp_url']}")
- except Exception as e:
- print(f"{ERROR}Failed to get browser status: {str(e)}{RESET}")
- return False
+ assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}"
+ status = await manager._strategy.get_builtin_browser_status()
# Step 2: Test killing the browser
- print(f"\n{INFO}2. Testing killing the browser{RESET}")
- try:
- result = await manager._strategy.kill_builtin_browser()
- if result:
- print(f"{SUCCESS}Browser killed successfully{RESET}")
- else:
- print(f"{ERROR}Failed to kill browser{RESET}")
- except Exception as e:
- print(f"{ERROR}Browser kill operation failed: {str(e)}{RESET}")
- return False
+ result = await manager._strategy.kill_builtin_browser()
+ assert result, "Failed to kill the browser"
# Step 3: Check status after kill
- print(f"\n{INFO}3. Checking status after kill{RESET}")
- try:
- status = await manager._strategy.get_builtin_browser_status()
- if not status["running"]:
- print(f"{SUCCESS}Browser is correctly reported as not running{RESET}")
- else:
- print(f"{ERROR}Browser is incorrectly reported as still running{RESET}")
- except Exception as e:
- print(f"{ERROR}Failed to get browser status: {str(e)}{RESET}")
- return False
+ status = await manager._strategy.get_builtin_browser_status()
+ assert status, "Failed to get browser status after kill"
+ assert not status["running"], "Browser is still running after kill"
# Step 4: Launch a new browser
- print(f"\n{INFO}4. Launching a new browser{RESET}")
- try:
- cdp_url = await manager._strategy.launch_builtin_browser(
- browser_type="chromium", headless=True
- )
- if cdp_url:
- print(f"{SUCCESS}New browser launched at: {cdp_url}{RESET}")
- else:
- print(f"{ERROR}Failed to launch new browser{RESET}")
- return False
- except Exception as e:
- print(f"{ERROR}Browser launch failed: {str(e)}{RESET}")
- return False
-
- return True
-
+ cdp_url = await manager._strategy.launch_builtin_browser(
+ browser_type="chromium", headless=True
+ )
+ assert cdp_url, "Failed to launch a new browser"
+@pytest.mark.asyncio
async def test_multiple_managers():
"""Test creating multiple BrowserManagers that use the same builtin browser"""
- print(f"\n{INFO}========== Testing Multiple Browser Managers =========={RESET}")
# Step 1: Create first manager
- print(f"\n{INFO}1. Creating first browser manager{RESET}")
- browser_config1 = (BrowserConfig(browser_mode="builtin", headless=True),)
- manager1 = BrowserManager(browser_config=browser_config1, logger=logger)
+ browser_config1: BrowserConfig = BrowserConfig(browser_mode="builtin", headless=True)
+ manager1: BrowserManager = BrowserManager(browser_config=browser_config1)
# Step 2: Create second manager
- print(f"\n{INFO}2. Creating second browser manager{RESET}")
- browser_config2 = BrowserConfig(browser_mode="builtin", headless=True)
- manager2 = BrowserManager(browser_config=browser_config2, logger=logger)
+ browser_config2: BrowserConfig = BrowserConfig(browser_mode="builtin", headless=True)
+ manager2: BrowserManager = BrowserManager(browser_config=browser_config2)
# Step 3: Start both managers (should connect to the same builtin browser)
- print(f"\n{INFO}3. Starting both managers{RESET}")
+ page1: Optional[Page] = None
+ page2: Optional[Page] = None
try:
await manager1.start()
- print(f"{SUCCESS}First manager started{RESET}")
-
await manager2.start()
- print(f"{SUCCESS}Second manager started{RESET}")
# Check if they got the same CDP URL
cdp_url1 = manager1._strategy.config.cdp_url
cdp_url2 = manager2._strategy.config.cdp_url
- if cdp_url1 == cdp_url2:
- print(
- f"{SUCCESS}Both managers connected to the same browser: {cdp_url1}{RESET}"
- )
- else:
- print(
- f"{WARNING}Managers connected to different browsers: {cdp_url1} and {cdp_url2}{RESET}"
- )
- except Exception as e:
- print(f"{ERROR}Failed to start managers: {str(e)}{RESET}")
- return False
+ assert cdp_url1 == cdp_url2, "CDP URLs do not match between managers"
- # Step 4: Test using both managers
- print(f"\n{INFO}4. Testing operations with both managers{RESET}")
- try:
+ # Step 4: Test using both managers
# First manager creates a page
page1, ctx1 = await manager1.get_page(CrawlerRunConfig())
await page1.goto("https://example.com")
- title1 = await page1.title()
- print(f"{SUCCESS}Manager 1 page title: {title1}{RESET}")
+ title1: str = await page1.title()
+ assert title1 == "Example Domain", f"Expected title 'Example Domain', got '{title1}'"
# Second manager creates a page
page2, ctx2 = await manager2.get_page(CrawlerRunConfig())
await page2.goto("https://example.org")
- title2 = await page2.title()
- print(f"{SUCCESS}Manager 2 page title: {title2}{RESET}")
-
- # Clean up
- await page1.close()
- await page2.close()
- except Exception as e:
- print(f"{ERROR}Failed to use both managers: {str(e)}{RESET}")
- return False
-
- # Step 5: Close both managers
- print(f"\n{INFO}5. Closing both managers{RESET}")
- try:
+ title2: str = await page2.title()
+ assert title2 == "Example Domain", f"Expected title 'Example Domain', got '{title2}'"
+ finally:
+ if page1:
+ await page1.close()
+ if page2:
+ await page2.close()
+ # Close both managers
await manager1.close()
- print(f"{SUCCESS}First manager closed{RESET}")
-
await manager2.close()
- print(f"{SUCCESS}Second manager closed{RESET}")
- except Exception as e:
- print(f"{ERROR}Failed to close managers: {str(e)}{RESET}")
- return False
- return True
-
-
-async def test_edge_cases():
- """Test edge cases like multiple starts, killing browser during operations, etc."""
- print(f"\n{INFO}========== Testing Edge Cases =========={RESET}")
-
- # Step 1: Test multiple starts with the same manager
- print(f"\n{INFO}1. Testing multiple starts with the same manager{RESET}")
- browser_config = BrowserConfig(browser_mode="builtin", headless=True)
- manager = BrowserManager(browser_config=browser_config, logger=logger)
+@pytest.mark.asyncio
+async def test_multiple_starts(manager: BrowserManager):
+ """Test multiple starts with the same manager."""
+ page: Optional[Page] = None
try:
- await manager.start()
- print(f"{SUCCESS}First start successful{RESET}")
-
# Try to start again
await manager.start()
- print(f"{SUCCESS}Second start completed without errors{RESET}")
# Test if it's still functional
page, context = await manager.get_page(CrawlerRunConfig())
+ assert page is not None, "Failed to create a page after multiple starts"
await page.goto("https://example.com")
- title = await page.title()
- print(
- f"{SUCCESS}Page operations work after multiple starts. Title: {title}{RESET}"
- )
- await page.close()
- except Exception as e:
- print(f"{ERROR}Multiple starts test failed: {str(e)}{RESET}")
- return False
+ title: str = await page.title()
+ assert title == "Example Domain", f"Expected title 'Example Domain', got '{title}'"
finally:
- await manager.close()
-
- # Step 2: Test killing the browser while manager is active
- print(f"\n{INFO}2. Testing killing the browser while manager is active{RESET}")
- manager = BrowserManager(browser_config=browser_config, logger=logger)
-
- try:
- await manager.start()
- print(f"{SUCCESS}Manager started{RESET}")
-
- # Kill the browser directly
- print(f"{INFO}Killing the browser...{RESET}")
- await manager._strategy.kill_builtin_browser()
- print(f"{SUCCESS}Browser killed{RESET}")
-
- # Try to get a page (should fail or launch a new browser)
- try:
- page, context = await manager.get_page(CrawlerRunConfig())
- print(
- f"{WARNING}Page request succeeded despite killed browser (might have auto-restarted){RESET}"
- )
- title = await page.title()
- print(f"{SUCCESS}Got page title: {title}{RESET}")
+ if page:
await page.close()
- except Exception as e:
- print(
- f"{SUCCESS}Page request failed as expected after browser was killed: {str(e)}{RESET}"
- )
- except Exception as e:
- print(f"{ERROR}Kill during operation test failed: {str(e)}{RESET}")
- return False
- finally:
await manager.close()
- return True
+@pytest.mark.asyncio
+async def test_kill_while_active(manager: BrowserManager):
+ """Test killing the browser while manager is active."""
+ assert isinstance(manager._strategy, BuiltinBrowserStrategy), f"Wrong strategy type {manager._strategy.__class__.__name__}"
+ await manager._strategy.kill_builtin_browser()
+ with pytest.raises(Exception):
+ # Try to get a page should fail
+ await manager.get_page(CrawlerRunConfig())
-async def cleanup_browsers():
- """Clean up any remaining builtin browsers"""
- print(f"\n{INFO}========== Cleaning Up Builtin Browsers =========={RESET}")
-
- browser_config = BrowserConfig(browser_mode="builtin", headless=True)
- manager = BrowserManager(browser_config=browser_config, logger=logger)
-
- try:
- # No need to start, just access the strategy directly
- strategy = manager._strategy
- if isinstance(strategy, BuiltinBrowserStrategy):
- result = await strategy.kill_builtin_browser()
- if result:
- print(f"{SUCCESS}Successfully killed all builtin browsers{RESET}")
- else:
- print(f"{WARNING}No builtin browsers found to kill{RESET}")
- else:
- print(f"{ERROR}Wrong strategy type: {strategy.__class__.__name__}{RESET}")
- except Exception as e:
- print(f"{ERROR}Cleanup failed: {str(e)}{RESET}")
- finally:
- # Just to be safe
- try:
- await manager.close()
- except:
- pass
-
-
-async def test_performance_scaling():
+@pytest.mark.asyncio
+@pytest.mark.timeout(30)
+async def test_performance_scaling(browsers_manager: BrowsersManager, test_server: TestServer):
"""Test performance with multiple browsers and pages.
This test creates multiple browsers on different ports,
spawns multiple pages per browser, and measures performance metrics.
"""
- print(f"\n{INFO}========== Testing Performance Scaling =========={RESET}")
-
- # Configuration parameters
- num_browsers = 10
- pages_per_browser = 10
- total_pages = num_browsers * pages_per_browser
- base_port = 9222
-
- # Set up a measuring mechanism for memory
- import psutil
- import gc
-
- # Force garbage collection before starting
- gc.collect()
- process = psutil.Process()
- initial_memory = process.memory_info().rss / 1024 / 1024 # in MB
- peak_memory = initial_memory
-
- # Report initial configuration
- print(
- f"{INFO}Test configuration: {num_browsers} browsers × {pages_per_browser} pages = {total_pages} total crawls{RESET}"
- )
-
- # List to track managers
- managers: List[BrowserManager] = []
- all_pages = []
-
-
-
- # Get crawl4ai home directory
- crawl4ai_home = os.path.expanduser("~/.crawl4ai")
- temp_dir = os.path.join(crawl4ai_home, "temp")
- os.makedirs(temp_dir, exist_ok=True)
-
- # Create all managers but don't start them yet
- manager_configs = []
- for i in range(num_browsers):
- port = base_port + i
- browser_config = BrowserConfig(
- browser_mode="builtin",
- headless=True,
- debugging_port=port,
- user_data_dir=os.path.join(temp_dir, f"browser_profile_{i}"),
- )
- manager = BrowserManager(browser_config=browser_config, logger=logger)
- manager._strategy.shutting_down = True
- manager_configs.append((manager, i, port))
-
- # Define async function to start a single manager
- async def start_manager(manager, index, port):
- try:
- await manager.start()
- return manager
- except Exception as e:
- print(
- f"{ERROR}Failed to start browser {index + 1} on port {port}: {str(e)}{RESET}"
- )
- return None
-
- # Start all managers in parallel
- start_tasks = [
- start_manager(manager, i, port) for manager, i, port in manager_configs
- ]
- started_managers = await asyncio.gather(*start_tasks)
-
- # Filter out None values (failed starts) and add to managers list
- managers = [m for m in started_managers if m is not None]
-
- if len(managers) == 0:
- print(f"{ERROR}All browser managers failed to start. Aborting test.{RESET}")
- return False
-
- if len(managers) < num_browsers:
- print(
- f"{WARNING}Only {len(managers)} out of {num_browsers} browser managers started successfully{RESET}"
- )
-
- # Create pages for each browser
- for i, manager in enumerate(managers):
- try:
- pages = await manager.get_pages(CrawlerRunConfig(), count=pages_per_browser)
- all_pages.extend(pages)
- except Exception as e:
- print(f"{ERROR}Failed to create pages for browser {i + 1}: {str(e)}{RESET}")
-
- # Check memory after page creation
- gc.collect()
- current_memory = process.memory_info().rss / 1024 / 1024
- peak_memory = max(peak_memory, current_memory)
+ assert browsers_manager.managers, "Failed to start any browser managers"
# Ask for confirmation before loading
confirmation = input(
f"{WARNING}Do you want to proceed with loading pages? (y/n): {RESET}"
- )
- # Step 1: Create and start multiple browser managers in parallel
- start_time = time.time()
-
- if confirmation.lower() == "y":
- load_start_time = time.time()
-
- # Function to load a single page
- async def load_page(page_ctx, index):
- page, _ = page_ctx
- try:
- await page.goto(f"https://example.com/page{index}", timeout=30000)
- title = await page.title()
- return title
- except Exception as e:
- return f"Error: {str(e)}"
-
- # Load all pages concurrently
- load_tasks = [load_page(page_ctx, i) for i, page_ctx in enumerate(all_pages)]
- load_results = await asyncio.gather(*load_tasks, return_exceptions=True)
-
- # Count successes and failures
- successes = sum(
- 1 for r in load_results if isinstance(r, str) and not r.startswith("Error")
- )
- failures = len(load_results) - successes
-
- load_time = time.time() - load_start_time
- total_test_time = time.time() - start_time
-
- # Check memory after loading (peak memory)
- gc.collect()
- current_memory = process.memory_info().rss / 1024 / 1024
- peak_memory = max(peak_memory, current_memory)
-
- # Calculate key metrics
- memory_per_page = peak_memory / successes if successes > 0 else 0
- time_per_crawl = total_test_time / successes if successes > 0 else 0
- crawls_per_second = successes / total_test_time if total_test_time > 0 else 0
- crawls_per_minute = crawls_per_second * 60
- crawls_per_hour = crawls_per_minute * 60
-
- # Print simplified performance summary
- from rich.console import Console
- from rich.table import Table
-
- console = Console()
-
- # Create a simple summary table
- table = Table(title="CRAWL4AI PERFORMANCE SUMMARY")
+ ) if sys.stdin.isatty() else "y"
- table.add_column("Metric", style="cyan")
- table.add_column("Value", style="green")
+ assert confirmation.lower() == "y", "User aborted the test"
- table.add_row("Total Crawls Completed", f"{successes}")
- table.add_row("Total Time", f"{total_test_time:.2f} seconds")
- table.add_row("Time Per Crawl", f"{time_per_crawl:.2f} seconds")
- table.add_row("Crawling Speed", f"{crawls_per_second:.2f} crawls/second")
- table.add_row("Projected Rate (1 minute)", f"{crawls_per_minute:.0f} crawls")
- table.add_row("Projected Rate (1 hour)", f"{crawls_per_hour:.0f} crawls")
- table.add_row("Peak Memory Usage", f"{peak_memory:.2f} MB")
- table.add_row("Memory Per Crawl", f"{memory_per_page:.2f} MB")
-
- # Display the table
- console.print(table)
-
- # Ask confirmation before cleanup
- confirmation = input(
- f"{WARNING}Do you want to proceed with cleanup? (y/n): {RESET}"
- )
- if confirmation.lower() != "y":
- print(f"{WARNING}Cleanup aborted by user{RESET}")
- return False
-
- # Close all pages
- for page, _ in all_pages:
- try:
- await page.close()
- except:
- pass
-
- # Close all managers
- for manager in managers:
- try:
- await manager.close()
- except:
- pass
-
- # Remove the temp directory
- import shutil
-
- if os.path.exists(temp_dir):
- shutil.rmtree(temp_dir)
-
- return True
-
-
-async def test_performance_scaling_lab( num_browsers: int = 10, pages_per_browser: int = 10):
- """Test performance with multiple browsers and pages.
-
- This test creates multiple browsers on different ports,
- spawns multiple pages per browser, and measures performance metrics.
- """
- print(f"\n{INFO}========== Testing Performance Scaling =========={RESET}")
-
- # Configuration parameters
- num_browsers = num_browsers
- pages_per_browser = pages_per_browser
- total_pages = num_browsers * pages_per_browser
- base_port = 9222
-
- # Set up a measuring mechanism for memory
- import psutil
- import gc
-
- # Force garbage collection before starting
- gc.collect()
- process = psutil.Process()
- initial_memory = process.memory_info().rss / 1024 / 1024 # in MB
- peak_memory = initial_memory
-
- # Report initial configuration
- print(
- f"{INFO}Test configuration: {num_browsers} browsers × {pages_per_browser} pages = {total_pages} total crawls{RESET}"
- )
-
- # List to track managers
- managers: List[BrowserManager] = []
- all_pages = []
-
- # Get crawl4ai home directory
- crawl4ai_home = os.path.expanduser("~/.crawl4ai")
- temp_dir = os.path.join(crawl4ai_home, "temp")
- os.makedirs(temp_dir, exist_ok=True)
-
- # Create all managers but don't start them yet
- manager_configs = []
- for i in range(num_browsers):
- port = base_port + i
- browser_config = BrowserConfig(
- browser_mode="builtin",
- headless=True,
- debugging_port=port,
- user_data_dir=os.path.join(temp_dir, f"browser_profile_{i}"),
- )
- manager = BrowserManager(browser_config=browser_config, logger=logger)
- manager._strategy.shutting_down = True
- manager_configs.append((manager, i, port))
-
- # Define async function to start a single manager
- async def start_manager(manager, index, port):
- try:
- await manager.start()
- return manager
- except Exception as e:
- print(
- f"{ERROR}Failed to start browser {index + 1} on port {port}: {str(e)}{RESET}"
- )
- return None
-
- # Start all managers in parallel
- start_tasks = [
- start_manager(manager, i, port) for manager, i, port in manager_configs
- ]
- started_managers = await asyncio.gather(*start_tasks)
-
- # Filter out None values (failed starts) and add to managers list
- managers = [m for m in started_managers if m is not None]
-
- if len(managers) == 0:
- print(f"{ERROR}All browser managers failed to start. Aborting test.{RESET}")
- return False
-
- if len(managers) < num_browsers:
- print(
- f"{WARNING}Only {len(managers)} out of {num_browsers} browser managers started successfully{RESET}"
- )
-
- # Create pages for each browser
- for i, manager in enumerate(managers):
- try:
- pages = await manager.get_pages(CrawlerRunConfig(), count=pages_per_browser)
- all_pages.extend(pages)
- except Exception as e:
- print(f"{ERROR}Failed to create pages for browser {i + 1}: {str(e)}{RESET}")
-
- # Check memory after page creation
- gc.collect()
- current_memory = process.memory_info().rss / 1024 / 1024
- peak_memory = max(peak_memory, current_memory)
-
- # Ask for confirmation before loading
- confirmation = input(
- f"{WARNING}Do you want to proceed with loading pages? (y/n): {RESET}"
- )
# Step 1: Create and start multiple browser managers in parallel
start_time = time.time()
-
- if confirmation.lower() == "y":
- load_start_time = time.time()
-
- # Function to load a single page
- async def load_page(page_ctx, index):
- page, _ = page_ctx
- try:
- await page.goto(f"https://example.com/page{index}", timeout=30000)
- title = await page.title()
- return title
- except Exception as e:
- return f"Error: {str(e)}"
-
- # Load all pages concurrently
- load_tasks = [load_page(page_ctx, i) for i, page_ctx in enumerate(all_pages)]
- load_results = await asyncio.gather(*load_tasks, return_exceptions=True)
-
- # Count successes and failures
- successes = sum(
- 1 for r in load_results if isinstance(r, str) and not r.startswith("Error")
- )
- failures = len(load_results) - successes
-
- load_time = time.time() - load_start_time
- total_test_time = time.time() - start_time
- # Check memory after loading (peak memory)
- gc.collect()
- current_memory = process.memory_info().rss / 1024 / 1024
- peak_memory = max(peak_memory, current_memory)
-
- # Calculate key metrics
- memory_per_page = peak_memory / successes if successes > 0 else 0
- time_per_crawl = total_test_time / successes if successes > 0 else 0
- crawls_per_second = successes / total_test_time if total_test_time > 0 else 0
- crawls_per_minute = crawls_per_second * 60
- crawls_per_hour = crawls_per_minute * 60
-
- # Print simplified performance summary
- from rich.console import Console
- from rich.table import Table
-
- console = Console()
-
- # Create a simple summary table
- table = Table(title="CRAWL4AI PERFORMANCE SUMMARY")
-
- table.add_column("Metric", style="cyan")
- table.add_column("Value", style="green")
-
- table.add_row("Total Crawls Completed", f"{successes}")
- table.add_row("Total Time", f"{total_test_time:.2f} seconds")
- table.add_row("Time Per Crawl", f"{time_per_crawl:.2f} seconds")
- table.add_row("Crawling Speed", f"{crawls_per_second:.2f} crawls/second")
- table.add_row("Projected Rate (1 minute)", f"{crawls_per_minute:.0f} crawls")
- table.add_row("Projected Rate (1 hour)", f"{crawls_per_hour:.0f} crawls")
- table.add_row("Peak Memory Usage", f"{peak_memory:.2f} MB")
- table.add_row("Memory Per Crawl", f"{memory_per_page:.2f} MB")
-
- # Display the table
- console.print(table)
-
- # Ask confirmation before cleanup
- confirmation = input(
- f"{WARNING}Do you want to proceed with cleanup? (y/n): {RESET}"
- )
- if confirmation.lower() != "y":
- print(f"{WARNING}Cleanup aborted by user{RESET}")
- return False
-
- # Close all pages
- for page, _ in all_pages:
+ # Function to load a single page
+ url = test_server.make_url("/page")
+ async def load_page(page_ctx: Tuple[Page, Any], index):
+ page, _ = page_ctx
try:
- await page.close()
- except:
- pass
-
- # Close all managers
- for manager in managers:
- try:
- await manager.close()
- except:
- pass
-
- # Remove the temp directory
- import shutil
-
- if os.path.exists(temp_dir):
- shutil.rmtree(temp_dir)
-
- return True
+ response: Optional[Response] = await page.goto(f"{url}{index}", timeout=5000) # example.com tends to hang connections under load.
+ if response is None:
+ print(f"{ERROR}Failed to load page {index}: No response{RESET}")
+ return "Error: No response"
+ if response.status != codes.OK:
+ print(f"{ERROR}Failed to load page {index}: {response.status}{RESET}")
+ return f"Error: {response.status}"
+ return await page.title()
+ except Exception as e:
+ print(f"{ERROR}Failed to load page {index}: {str(e)}{RESET}")
+ return f"Error: {str(e)}"
-async def main():
- """Run all tests"""
- try:
- print(f"{INFO}Starting builtin browser tests with browser module{RESET}")
+ # Load all pages concurrently
+ load_tasks = [load_page(page_ctx, i + 1) for i, page_ctx in enumerate(browsers_manager.all_pages)]
+ load_results = await asyncio.gather(*load_tasks)
- # # Run browser creation test
- # manager, cdp_url = await test_builtin_browser_creation()
- # if not manager:
- # print(f"{ERROR}Browser creation failed, cannot continue tests{RESET}")
- # return
+ # Count successes and failures
+ successes = sum(
+ 1 for r in load_results if isinstance(r, str) and not r.startswith("Error")
+ )
+ failures = len(load_results) - successes
- # # Run page operations test
- # await test_page_operations(manager)
+ assert not failures, f"Failed to load {failures} pages"
- # # Run browser status and management test
- # await test_browser_status_management(manager)
+ total_test_time = time.time() - start_time
- # # Close manager before multiple manager test
- # await manager.close()
+ # Check memory after loading (peak memory)
+ browsers_manager.check_memory()
- # Run multiple managers test
- # await test_multiple_managers()
+ # Calculate key metrics
+ memory_per_page = browsers_manager.peak_memory / successes if successes > 0 else 0
+ time_per_crawl = total_test_time / successes if successes > 0 else 0
+ crawls_per_second = successes / total_test_time if total_test_time > 0 else 0
+ crawls_per_minute = crawls_per_second * 60
+ crawls_per_hour = crawls_per_minute * 60
- # Run performance scaling test
- await test_performance_scaling()
- # Run cleanup test
- # await cleanup_browsers()
+ # Print simplified performance summary
+ from rich.console import Console
+ from rich.table import Table
- # Run edge cases test
- # await test_edge_cases()
+ console = Console()
- print(f"\n{SUCCESS}All tests completed!{RESET}")
+ # Create a simple summary table
+ table = Table(title="CRAWL4AI PERFORMANCE SUMMARY")
- except Exception as e:
- print(f"\n{ERROR}Test failed with error: {str(e)}{RESET}")
- import traceback
+ table.add_column("Metric", style="cyan")
+ table.add_column("Value", style="green")
- traceback.print_exc()
- finally:
- # Clean up: kill any remaining builtin browsers
- await cleanup_browsers()
- print(f"{SUCCESS}Test cleanup complete{RESET}")
+ table.add_row("Total Crawls Completed", f"{successes}")
+ table.add_row("Total Time", f"{total_test_time:.2f} seconds")
+ table.add_row("Time Per Crawl", f"{time_per_crawl:.2f} seconds")
+ table.add_row("Crawling Speed", f"{crawls_per_second:.2f} crawls/second")
+ table.add_row("Projected Rate (1 minute)", f"{crawls_per_minute:.0f} crawls")
+ table.add_row("Projected Rate (1 hour)", f"{crawls_per_hour:.0f} crawls")
+ table.add_row("Peak Memory Usage", f"{browsers_manager.peak_memory:.2f} MB")
+ table.add_row("Memory Per Crawl", f"{memory_per_page:.2f} MB")
+ # Display the table
+ console.print(table)
if __name__ == "__main__":
- asyncio.run(main())
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
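Note: the refactored test depends on `browsers_manager` and `test_server` fixtures that are not part of this hunk. As orientation only, here is a minimal conftest-style sketch of what such a fixture could look like, reusing the BrowserManager startup, page creation, and psutil memory tracking that the removed inline code performed (all names, counts, and signatures below are illustrative assumptions, not the project's actual fixtures):

    # Hypothetical sketch only; the real BrowsersManager / fixtures live elsewhere in the repo.
    import gc
    from dataclasses import dataclass, field
    from typing import Any, List, Tuple

    import psutil
    import pytest_asyncio

    from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
    from crawl4ai.browser import BrowserManager

    @dataclass
    class BrowsersManager:
        managers: List[BrowserManager] = field(default_factory=list)
        all_pages: List[Tuple[Any, Any]] = field(default_factory=list)
        peak_memory: float = 0.0  # MB, high-water mark of process RSS

        def check_memory(self) -> None:
            # Same psutil-based tracking the removed inline code did.
            gc.collect()
            rss_mb = psutil.Process().memory_info().rss / 1024 / 1024
            self.peak_memory = max(self.peak_memory, rss_mb)

    @pytest_asyncio.fixture
    async def browsers_manager():
        bm = BrowsersManager()
        for _ in range(2):  # small numbers for the sketch; the real test scales this up
            manager = BrowserManager(browser_config=BrowserConfig(headless=True))
            await manager.start()
            bm.managers.append(manager)
            bm.all_pages.extend(await manager.get_pages(CrawlerRunConfig(), count=2))
        bm.check_memory()
        yield bm
        for manager in bm.managers:
            await manager.close()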
diff --git a/tests/browser/test_builtin_strategy.py b/tests/browser/test_builtin_strategy.py
index 7c435b3de..0f8643013 100644
--- a/tests/browser/test_builtin_strategy.py
+++ b/tests/browser/test_builtin_strategy.py
@@ -4,13 +4,8 @@
and serve as functional tests.
"""
-import asyncio
-import os
import sys
-
-# Add the project root to Python path if running directly
-if __name__ == "__main__":
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+import pytest
from crawl4ai.browser import BrowserManager
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
@@ -19,6 +14,7 @@
# Create a logger for clear terminal output
logger = AsyncLogger(verbose=True, log_file=None)
+@pytest.mark.asyncio
async def test_builtin_browser():
"""Test using a builtin browser that persists between sessions."""
logger.info("Testing builtin browser", tag="TEST")
@@ -68,10 +64,11 @@ async def test_builtin_browser():
logger.error(f"Test failed: {str(e)}", tag="TEST")
try:
await manager.close()
- except:
+ except Exception:
pass
return False
+@pytest.mark.asyncio
async def test_builtin_browser_status():
"""Test getting status of the builtin browser."""
logger.info("Testing builtin browser status", tag="TEST")
@@ -135,26 +132,11 @@ async def test_builtin_browser_status():
# Try to kill the builtin browser to clean up
strategy2 = BuiltinBrowserStrategy(browser_config, logger)
await strategy2.kill_builtin_browser()
- except:
+ except Exception:
pass
return False
-async def run_tests():
- """Run all tests sequentially."""
- results = []
-
- results.append(await test_builtin_browser())
- results.append(await test_builtin_browser_status())
-
- # Print summary
- total = len(results)
- passed = sum(results)
- logger.info(f"Tests complete: {passed}/{total} passed", tag="SUMMARY")
-
- if passed == total:
- logger.success("All tests passed!", tag="SUMMARY")
- else:
- logger.error(f"{total - passed} tests failed", tag="SUMMARY")
-
if __name__ == "__main__":
- asyncio.run(run_tests())
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/browser/test_cdp_strategy.py b/tests/browser/test_cdp_strategy.py
index abadf42a2..54bb8bfab 100644
--- a/tests/browser/test_cdp_strategy.py
+++ b/tests/browser/test_cdp_strategy.py
@@ -4,13 +4,8 @@
and serve as functional tests.
"""
-import asyncio
-import os
import sys
-
-# Add the project root to Python path if running directly
-if __name__ == "__main__":
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+import pytest
from crawl4ai.browser import BrowserManager
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
@@ -19,6 +14,7 @@
# Create a logger for clear terminal output
logger = AsyncLogger(verbose=True, log_file=None)
+@pytest.mark.asyncio
async def test_cdp_launch_connect():
"""Test launching a browser and connecting via CDP."""
logger.info("Testing launch and connect via CDP", tag="TEST")
@@ -56,10 +52,11 @@ async def test_cdp_launch_connect():
logger.error(f"Test failed: {str(e)}", tag="TEST")
try:
await manager.close()
- except:
+ except Exception:
pass
return False
+@pytest.mark.asyncio
async def test_cdp_with_user_data_dir():
"""Test CDP browser with a user data directory."""
logger.info("Testing CDP browser with user data directory", tag="TEST")
@@ -124,25 +121,26 @@ async def test_cdp_with_user_data_dir():
# Remove temporary directory
import shutil
shutil.rmtree(user_data_dir, ignore_errors=True)
- logger.info(f"Removed temporary user data directory", tag="TEST")
-
+ logger.info("Removed temporary user data directory", tag="TEST")
+
return has_test_cookie and has_test_cookie2
except Exception as e:
logger.error(f"Test failed: {str(e)}", tag="TEST")
try:
await manager.close()
- except:
+ except Exception:
pass
# Clean up temporary directory
try:
import shutil
shutil.rmtree(user_data_dir, ignore_errors=True)
- except:
+ except Exception:
pass
return False
+@pytest.mark.asyncio
async def test_cdp_session_management():
"""Test session management with CDP browser."""
logger.info("Testing session management with CDP browser", tag="TEST")
@@ -186,8 +184,8 @@ async def test_cdp_session_management():
# Kill first session
await manager.kill_session(session1_id)
- logger.info(f"Killed session 1", tag="TEST")
-
+ logger.info("Killed session 1", tag="TEST")
+
# Verify second session still works
data2 = await page2.evaluate("localStorage.getItem('session2_data')")
logger.info(f"Session 2 still functional after killing session 1, data: {data2}", tag="TEST")
@@ -201,27 +199,11 @@ async def test_cdp_session_management():
logger.error(f"Test failed: {str(e)}", tag="TEST")
try:
await manager.close()
- except:
+ except Exception:
pass
return False
-async def run_tests():
- """Run all tests sequentially."""
- results = []
-
- # results.append(await test_cdp_launch_connect())
- # results.append(await test_cdp_with_user_data_dir())
- results.append(await test_cdp_session_management())
-
- # Print summary
- total = len(results)
- passed = sum(results)
- logger.info(f"Tests complete: {passed}/{total} passed", tag="SUMMARY")
-
- if passed == total:
- logger.success("All tests passed!", tag="SUMMARY")
- else:
- logger.error(f"{total - passed} tests failed", tag="SUMMARY")
-
if __name__ == "__main__":
- asyncio.run(run_tests())
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/browser/test_combined.py b/tests/browser/test_combined.py
deleted file mode 100644
index b5bce3cda..000000000
--- a/tests/browser/test_combined.py
+++ /dev/null
@@ -1,77 +0,0 @@
-"""Combined test runner for all browser module tests.
-
-This script runs all the browser module tests in sequence and
-provides a comprehensive summary.
-"""
-
-import asyncio
-import os
-import sys
-import time
-
-# Add the project root to Python path if running directly
-if __name__ == "__main__":
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
-
-from crawl4ai.async_logger import AsyncLogger
-
-# Create a logger for clear terminal output
-logger = AsyncLogger(verbose=True, log_file=None)
-
-async def run_test_module(module_name, header):
- """Run all tests in a module and return results."""
- logger.info(f"\n{'-'*30}", tag="TEST")
- logger.info(f"RUNNING: {header}", tag="TEST")
- logger.info(f"{'-'*30}", tag="TEST")
-
- # Import the module dynamically
- module = __import__(f"tests.browser.{module_name}", fromlist=["run_tests"])
-
- # Track time for performance measurement
- start_time = time.time()
-
- # Run the tests
- await module.run_tests()
-
- # Calculate time taken
- time_taken = time.time() - start_time
- logger.info(f"Time taken: {time_taken:.2f} seconds", tag="TIMING")
-
- return time_taken
-
-async def main():
- """Run all test modules."""
- logger.info("STARTING COMPREHENSIVE BROWSER MODULE TESTS", tag="MAIN")
-
- # List of test modules to run
- test_modules = [
- ("test_browser_manager", "Browser Manager Tests"),
- ("test_playwright_strategy", "Playwright Strategy Tests"),
- ("test_cdp_strategy", "CDP Strategy Tests"),
- ("test_builtin_strategy", "Builtin Browser Strategy Tests"),
- ("test_profiles", "Profile Management Tests")
- ]
-
- # Run each test module
- timings = {}
- for module_name, header in test_modules:
- try:
- time_taken = await run_test_module(module_name, header)
- timings[module_name] = time_taken
- except Exception as e:
- logger.error(f"Error running {module_name}: {str(e)}", tag="ERROR")
-
- # Print summary
- logger.info("\n\nTEST SUMMARY:", tag="SUMMARY")
- logger.info(f"{'-'*50}", tag="SUMMARY")
- for module_name, header in test_modules:
- if module_name in timings:
- logger.info(f"{header}: {timings[module_name]:.2f} seconds", tag="SUMMARY")
- else:
- logger.error(f"{header}: FAILED TO RUN", tag="SUMMARY")
- logger.info(f"{'-'*50}", tag="SUMMARY")
- total_time = sum(timings.values())
- logger.info(f"Total time: {total_time:.2f} seconds", tag="SUMMARY")
-
-if __name__ == "__main__":
- asyncio.run(main())
diff --git a/tests/browser/test_launch_standalone.py b/tests/browser/test_launch_standalone.py
index d60b12f3f..28f2ade11 100644
--- a/tests/browser/test_launch_standalone.py
+++ b/tests/browser/test_launch_standalone.py
@@ -1,17 +1,21 @@
+import os
+import sys
+import pytest
from crawl4ai.browser_profiler import BrowserProfiler
-import asyncio
+@pytest.mark.asyncio
+@pytest.mark.skip(reason="Requires user interaction to stop the browser, more work needed")
+async def test_standalone_browser():
+ profiler = BrowserProfiler()
+ cdp_url = await profiler.launch_standalone_browser(
+ browser_type="chromium",
+ user_data_dir=os.path.expanduser("~/.crawl4ai/browser_profile/test-browser-data"),
+ debugging_port=9222,
+ headless=False
+ )
+ assert cdp_url is not None, "Failed to launch standalone browser"
if __name__ == "__main__":
- # Test launching a standalone browser
- async def test_standalone_browser():
- profiler = BrowserProfiler()
- cdp_url = await profiler.launch_standalone_browser(
- browser_type="chromium",
- user_data_dir="~/.crawl4ai/browser_profile/test-browser-data",
- debugging_port=9222,
- headless=False
- )
- print(f"CDP URL: {cdp_url}")
+ import subprocess
- asyncio.run(test_standalone_browser())
\ No newline at end of file
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
\ No newline at end of file
diff --git a/tests/browser/test_parallel_crawling.py b/tests/browser/test_parallel_crawling.py
index 9e72f06e3..92657ea16 100644
--- a/tests/browser/test_parallel_crawling.py
+++ b/tests/browser/test_parallel_crawling.py
@@ -6,14 +6,8 @@
"""
import asyncio
-import os
-import sys
import time
-from typing import List
-
-# Add the project root to Python path if running directly
-if __name__ == "__main__":
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+import pytest
from crawl4ai.browser import BrowserManager
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
@@ -22,6 +16,7 @@
# Create a logger for clear terminal output
logger = AsyncLogger(verbose=True, log_file=None)
+@pytest.mark.asyncio
async def test_get_pages_basic():
"""Test basic functionality of get_pages method."""
logger.info("Testing basic get_pages functionality", tag="TEST")
@@ -53,10 +48,11 @@ async def test_get_pages_basic():
logger.error(f"Test failed: {str(e)}", tag="TEST")
try:
await manager.close()
- except:
+ except Exception:
pass
return False
+@pytest.mark.asyncio
async def test_parallel_approaches_comparison():
"""Compare two parallel crawling approaches:
1. Create a page for each URL on-demand (get_page + gather)
@@ -148,10 +144,11 @@ async def fetch_title_approach2(page_ctx, url):
logger.error(f"Test failed: {str(e)}", tag="TEST")
try:
await manager.close()
- except:
+ except Exception:
pass
return False
+@pytest.mark.asyncio
async def test_multi_browser_scaling(num_browsers=3, pages_per_browser=5):
"""Test performance with multiple browsers and pages per browser.
Compares two approaches:
@@ -166,8 +163,7 @@ async def test_multi_browser_scaling(num_browsers=3, pages_per_browser=5):
# Create browser managers
managers = []
- base_port = 9222
-
+
try:
# Start all browsers in parallel
start_tasks = []
@@ -268,7 +264,7 @@ async def fetch_title_approach2(page_ctx, url):
for manager in managers:
try:
await manager.close()
- except:
+ except Exception:
pass
return False
@@ -469,7 +465,7 @@ async def fetch_title(page_ctx, url):
for manager in managers:
try:
await manager.close()
- except:
+ except Exception:
pass
# Print summary of all configurations
@@ -880,7 +876,7 @@ async def run_tests():
if configs:
# Show the optimal configuration
optimal = configs["optimal"]
- print(f"\n🎯 Recommended configuration for production use:")
+ print("\n🎯 Recommended configuration for production use:")
print(f" {optimal['browser_count']} browsers with distribution {optimal['distribution']}")
print(f" Estimated performance: {optimal['pages_per_second']:.1f} pages/second")
results.append(True)
diff --git a/tests/browser/test_playwright_strategy.py b/tests/browser/test_playwright_strategy.py
index 2344c9bae..c0f209b33 100644
--- a/tests/browser/test_playwright_strategy.py
+++ b/tests/browser/test_playwright_strategy.py
@@ -4,13 +4,8 @@
and serve as functional tests.
"""
-import asyncio
-import os
import sys
-
-# Add the project root to Python path if running directly
-if __name__ == "__main__":
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+import pytest
from crawl4ai.browser import BrowserManager
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
@@ -19,6 +14,7 @@
# Create a logger for clear terminal output
logger = AsyncLogger(verbose=True, log_file=None)
+@pytest.mark.asyncio
async def test_playwright_basic():
"""Test basic Playwright browser functionality."""
logger.info("Testing standard Playwright browser", tag="TEST")
@@ -63,10 +59,11 @@ async def test_playwright_basic():
# Ensure cleanup
try:
await manager.close()
- except:
+ except Exception:
pass
return False
+@pytest.mark.asyncio
async def test_playwright_text_mode():
"""Test Playwright browser in text-only mode."""
logger.info("Testing Playwright text mode", tag="TEST")
@@ -106,7 +103,7 @@ async def test_playwright_text_mode():
await page.goto("https://picsum.photos/", wait_until="domcontentloaded")
request = await request_info.value
has_images = True
- except:
+ except Exception:
# Timeout without image requests means text mode is working
has_images = False
@@ -122,10 +119,11 @@ async def test_playwright_text_mode():
# Ensure cleanup
try:
await manager.close()
- except:
+ except Exception:
pass
return False
+@pytest.mark.asyncio
async def test_playwright_context_reuse():
"""Test context caching and reuse with identical configurations."""
logger.info("Testing context reuse with identical configurations", tag="TEST")
@@ -178,10 +176,11 @@ async def test_playwright_context_reuse():
# Ensure cleanup
try:
await manager.close()
- except:
+ except Exception:
pass
return False
+@pytest.mark.asyncio
async def test_playwright_session_management():
"""Test session management with Playwright browser."""
logger.info("Testing session management with Playwright browser", tag="TEST")
@@ -225,8 +224,8 @@ async def test_playwright_session_management():
# Kill first session
await manager.kill_session(session1_id)
- logger.info(f"Killed session 1", tag="TEST")
-
+ logger.info("Killed session 1", tag="TEST")
+
# Verify second session still works
data2 = await page2.evaluate("localStorage.getItem('playwright_session2_data')")
logger.info(f"Session 2 still functional after killing session 1, data: {data2}", tag="TEST")
@@ -240,28 +239,12 @@ async def test_playwright_session_management():
logger.error(f"Test failed: {str(e)}", tag="TEST")
try:
await manager.close()
- except:
+ except Exception:
pass
return False
-async def run_tests():
- """Run all tests sequentially."""
- results = []
-
- results.append(await test_playwright_basic())
- results.append(await test_playwright_text_mode())
- results.append(await test_playwright_context_reuse())
- results.append(await test_playwright_session_management())
-
- # Print summary
- total = len(results)
- passed = sum(results)
- logger.info(f"Tests complete: {passed}/{total} passed", tag="SUMMARY")
-
- if passed == total:
- logger.success("All tests passed!", tag="SUMMARY")
- else:
- logger.error(f"{total - passed} tests failed", tag="SUMMARY")
if __name__ == "__main__":
- asyncio.run(run_tests())
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/browser/test_profiles.py b/tests/browser/test_profiles.py
index 8325b561a..0a8c32c75 100644
--- a/tests/browser/test_profiles.py
+++ b/tests/browser/test_profiles.py
@@ -4,15 +4,11 @@
and serve as functional tests.
"""
-import asyncio
import os
import sys
import uuid
import shutil
-
-# Add the project root to Python path if running directly
-if __name__ == "__main__":
- sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
+import pytest
from crawl4ai.browser import BrowserManager, BrowserProfileManager
from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig
@@ -21,6 +17,7 @@
# Create a logger for clear terminal output
logger = AsyncLogger(verbose=True, log_file=None)
+@pytest.mark.asyncio
async def test_profile_creation():
"""Test creating and managing browser profiles."""
logger.info("Testing profile creation and management", tag="TEST")
@@ -75,10 +72,11 @@ async def test_profile_creation():
try:
if os.path.exists(profile_path):
shutil.rmtree(profile_path, ignore_errors=True)
- except:
+ except Exception:
pass
return False
+@pytest.mark.asyncio
async def test_profile_with_browser():
"""Test using a profile with a browser."""
logger.info("Testing using a profile with a browser", tag="TEST")
@@ -151,26 +149,11 @@ async def test_profile_with_browser():
try:
if profile_path and os.path.exists(profile_path):
shutil.rmtree(profile_path, ignore_errors=True)
- except:
+ except Exception:
pass
return False
-async def run_tests():
- """Run all tests sequentially."""
- results = []
-
- results.append(await test_profile_creation())
- results.append(await test_profile_with_browser())
-
- # Print summary
- total = len(results)
- passed = sum(results)
- logger.info(f"Tests complete: {passed}/{total} passed", tag="SUMMARY")
-
- if passed == total:
- logger.success("All tests passed!", tag="SUMMARY")
- else:
- logger.error(f"{total - passed} tests failed", tag="SUMMARY")
-
if __name__ == "__main__":
- asyncio.run(run_tests())
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
\ No newline at end of file
diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py
index b7416dc29..f9ddea738 100644
--- a/tests/cli/test_cli.py
+++ b/tests/cli/test_cli.py
@@ -1,17 +1,22 @@
-import pytest
-from click.testing import CliRunner
-from pathlib import Path
import json
-import yaml
-from crawl4ai.cli import cli, load_config_file, parse_key_values
-import tempfile
import os
+import sys
+import tempfile
+from pathlib import Path
+
import click
+import pytest
+import yaml
+from click.testing import CliRunner, Result
+
+from crawl4ai.cli import cli, load_config_file, parse_key_values
+
@pytest.fixture
-def runner():
+def runner() -> CliRunner:
return CliRunner()
+
@pytest.fixture
def temp_config_dir():
with tempfile.TemporaryDirectory() as tmpdir:
@@ -21,8 +26,9 @@ def temp_config_dir():
if old_home:
os.environ['HOME'] = old_home
+
@pytest.fixture
-def sample_configs(temp_config_dir):
+def sample_configs(temp_config_dir: Path) -> dict[str, str]:
configs = {
'browser.yml': {
'headless': True,
@@ -59,20 +65,21 @@ def sample_configs(temp_config_dir):
return {name: str(temp_config_dir / name) for name in configs}
class TestCLIBasics:
- def test_help(self, runner):
- result = runner.invoke(cli, ['--help'])
+ def test_help(self, runner: CliRunner):
+ result: Result = runner.invoke(cli, ['--help'])
assert result.exit_code == 0
assert 'Crawl4AI CLI' in result.output
- def test_examples(self, runner):
- result = runner.invoke(cli, ['--example'])
+ def test_examples(self, runner: CliRunner):
+ result: Result = runner.invoke(cli, ['examples'])
assert result.exit_code == 0
assert 'Examples' in result.output
def test_missing_url(self, runner):
- result = runner.invoke(cli)
+ result: Result = runner.invoke(cli, ['crawl'])
assert result.exit_code != 0
- assert 'URL argument is required' in result.output
+ assert "Error: Missing argument 'URL'" in result.output
+
class TestConfigParsing:
def test_parse_key_values_basic(self):
@@ -99,35 +106,38 @@ def test_load_nonexistent_config(self):
load_config_file('nonexistent.yml')
class TestLLMConfig:
- def test_llm_config_creation(self, temp_config_dir, runner):
+ def test_llm_config_creation(self, temp_config_dir: Path, runner: CliRunner):
def input_simulation(inputs):
return runner.invoke(cli, ['https://example.com', '-q', 'test question'],
input='\n'.join(inputs))
class TestCrawlingFeatures:
- def test_basic_crawl(self, runner):
- result = runner.invoke(cli, ['https://example.com'])
+ def test_basic_crawl(self, runner: CliRunner):
+ result: Result = runner.invoke(cli, ['crawl', 'https://example.com'])
assert result.exit_code == 0
class TestErrorHandling:
- def test_invalid_config_file(self, runner):
- result = runner.invoke(cli, [
+ def test_invalid_config_file(self, runner: CliRunner):
+ result: Result = runner.invoke(cli, [
'https://example.com',
'--browser-config', 'nonexistent.yml'
])
assert result.exit_code != 0
- def test_invalid_schema(self, runner, temp_config_dir):
+ def test_invalid_schema(self, runner: CliRunner, temp_config_dir: Path):
invalid_schema = temp_config_dir / 'invalid_schema.json'
with open(invalid_schema, 'w') as f:
f.write('invalid json')
-
- result = runner.invoke(cli, [
+
+ result: Result = runner.invoke(cli, [
'https://example.com',
'--schema', str(invalid_schema)
])
assert result.exit_code != 0
-if __name__ == '__main__':
- pytest.main(['-v', '-s', '--tb=native', __file__])
\ No newline at end of file
+
+if __name__ == "__main__":
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/docker/__init__.py b/tests/docker/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/docker/common.py b/tests/docker/common.py
new file mode 100644
index 000000000..59bd47319
--- /dev/null
+++ b/tests/docker/common.py
@@ -0,0 +1,51 @@
+from typing import List
+
+import pytest
+from _pytest.mark import ParameterSet
+from httpx import ASGITransport, AsyncClient
+
+from crawl4ai.docker_client import Crawl4aiDockerClient
+from deploy.docker.server import app
+
+TEST_URLS = [
+ "example.com",
+ "https://www.python.org",
+ "https://news.ycombinator.com/news",
+ "https://github.com/trending",
+]
+BASE_URL = "http://localhost:8000"
+
+
+def async_client() -> AsyncClient:
+ """Create an async client for the server.
+
+ This can be used to test the API server without running the server."""
+ return AsyncClient(transport=ASGITransport(app=app), base_url=BASE_URL)
+
+
+def docker_client() -> Crawl4aiDockerClient:
+ """Crawl4aiDockerClient docker client via local transport.
+
+ This can be used to test the API server without running the server."""
+ return Crawl4aiDockerClient(transport=ASGITransport(app=app), verbose=True)
+
+
+def markdown_params(urls: List[str] = TEST_URLS) -> List[ParameterSet]:
+ """Parameters for markdown endpoint tests with different filters"""
+ tests = []
+ for url in urls:
+ for filter_type in ["raw", "fit", "bm25", "llm"]:
+ for cache in ["0", "1"]:
+ params: dict[str, str] = {"f": filter_type, "c": cache}
+ if filter_type in ["bm25", "llm"]:
+ params["q"] = "extract main content"
+
+ tests.append(
+ pytest.param(
+ url,
+ params,
+ id=f"{url} {filter_type}" + (" cached" if cache == "1" else ""),
+ )
+ )
+
+ return tests
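markdown_params() only builds the parameter sets; a consumer would typically feed them to pytest.mark.parametrize. A hedged sketch of such a test follows (the `/md/{url}` route and plain-text response are assumptions for illustration, not taken from this diff):

    # Hypothetical consumer of markdown_params(); endpoint path and response shape are assumed.
    import pytest
    from httpx import codes

    from .common import async_client, markdown_params

    @pytest.mark.asyncio
    @pytest.mark.parametrize("url,params", markdown_params())
    async def test_markdown_filters(url: str, params: dict):
        async with async_client() as client:
            response = await client.get(f"/md/{url}", params=params)
            assert response.status_code == codes.OK
            assert response.text  # some markdown should come back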
diff --git a/tests/docker/test_config_object.py b/tests/docker/test_config_object.py
index 94a30f058..b723cf7d9 100644
--- a/tests/docker/test_config_object.py
+++ b/tests/docker/test_config_object.py
@@ -1,4 +1,5 @@
import json
+import sys
from crawl4ai import (
CrawlerRunConfig,
DefaultMarkdownGenerator,
@@ -8,9 +9,8 @@
CacheMode
)
from crawl4ai.deep_crawling import BFSDeepCrawlStrategy
-from crawl4ai.deep_crawling.filters import FastFilterChain
-from crawl4ai.deep_crawling.filters import FastContentTypeFilter, FastDomainFilter
-from crawl4ai.deep_crawling.scorers import FastKeywordRelevanceScorer
+from crawl4ai.deep_crawling.filters import FilterChain, ContentTypeFilter, DomainFilter
+from crawl4ai.deep_crawling.scorers import KeywordRelevanceScorer
def create_test_config() -> CrawlerRunConfig:
# Set up content filtering and markdown generation
@@ -35,12 +35,12 @@ def create_test_config() -> CrawlerRunConfig:
extraction_strategy = JsonCssExtractionStrategy(schema=extraction_schema)
# Set up deep crawling
- filter_chain = FastFilterChain([
- FastContentTypeFilter(["text/html"]),
- FastDomainFilter(blocked_domains=["ads.*"])
+ filter_chain = FilterChain([
+ ContentTypeFilter(["text/html"]),
+ DomainFilter(blocked_domains=["ads.*"]),
])
- url_scorer = FastKeywordRelevanceScorer(
+ url_scorer = KeywordRelevanceScorer(
keywords=["article", "blog"],
weight=1.0
)
@@ -104,10 +104,12 @@ def test_config_serialization_cycle():
# Verify deep crawl strategy configuration
assert deserialized_config.deep_crawl_strategy.max_depth == 3
- assert isinstance(deserialized_config.deep_crawl_strategy.filter_chain, FastFilterChain)
- assert isinstance(deserialized_config.deep_crawl_strategy.url_scorer, FastKeywordRelevanceScorer)
+ assert isinstance(deserialized_config.deep_crawl_strategy.filter_chain, FilterChain)
+ assert isinstance(deserialized_config.deep_crawl_strategy.url_scorer, KeywordRelevanceScorer)
print("Serialization cycle test passed successfully!")
if __name__ == "__main__":
- test_config_serialization_cycle()
\ No newline at end of file
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/docker/test_core.py b/tests/docker/test_core.py
new file mode 100644
index 000000000..9954b55b8
--- /dev/null
+++ b/tests/docker/test_core.py
@@ -0,0 +1,322 @@
+import base64
+import io
+import json
+import os
+import sys
+import time
+from typing import Any, Dict
+
+import pytest
+import requests
+from PIL import Image, ImageFile
+
+from crawl4ai import (
+ CosineStrategy,
+ JsonCssExtractionStrategy,
+ LLMConfig,
+ LLMExtractionStrategy,
+)
+from crawl4ai.async_configs import BrowserConfig
+from crawl4ai.async_webcrawler import CrawlerRunConfig
+
+from .common import async_client
+
+
+class Crawl4AiTester:
+ def __init__(self):
+ self.client = async_client()
+
+ async def submit_and_validate(
+ self, request_data: Dict[str, Any], timeout: int = 300
+ ) -> Dict[str, Any]:
+ """Submit a crawl request and validate for the result.
+
+ The response is validated to ensure that it is successful and contains at least one result.
+
+ :param request_data: The request data to submit.
+ :type request_data: Dict[str, Any]
+ :param timeout: The maximum time to wait for the response, defaults to 300.
+ :type timeout: int, optional
+        :return: The first result from the crawl response, decoded from JSON.
+ :rtype: Dict[str, Any]
+ """
+ response = await self.client.post("/crawl", json=request_data)
+ return self.assert_valid_response(response.json())
+
+ async def check_health(self) -> None:
+ """Check the health of the service.
+
+ Check the health of the service and wait for it to be ready.
+ If the service is not ready after a few retries, the test will fail."""
+ max_retries = 5
+ for i in range(max_retries):
+ try:
+ health = await self.client.get("/health", timeout=10)
+ print("Health check:", health.json())
+ return
+ except requests.exceptions.RequestException:
+ if i == max_retries - 1:
+ assert False, f"Failed to connect after {max_retries} attempts"
+
+ print(
+ f"Waiting for service to start (attempt {i + 1}/{max_retries})..."
+ )
+ time.sleep(5)
+
+ pytest.fail(f"Failed to connect to service after {max_retries} retries")
+
+ def assert_valid_response(self, result: Dict[str, Any]) -> Dict[str, Any]:
+ """Assert that the response is valid and returns the first result.
+
+ :param result: The response from the API
+ :type result: Dict[str, Any]
+ :return: The first result
+ :rtype: dict[str, Any]
+ """
+ assert result["success"]
+ assert result["results"]
+ assert len(result["results"]) > 0
+ assert result["results"][0]["url"]
+ assert result["results"][0]["html"]
+ return result["results"][0]
+
+
+@pytest.fixture
+def tester() -> Crawl4AiTester:
+ return Crawl4AiTester()
+
+
+@pytest.mark.asyncio
+async def test_basic_crawl(tester: Crawl4AiTester):
+ request = {"urls": ["https://www.nbcnews.com/business"], "priority": 10}
+
+ await tester.submit_and_validate(request)
+
+
+@pytest.mark.asyncio
+async def test_js_execution(tester: Crawl4AiTester):
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 8,
+ "js_code": [
+ "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
+ ],
+ "wait_for": "article.tease-card:nth-child(10)",
+ "crawler_params": {"headless": True},
+ }
+
+ await tester.submit_and_validate(request)
+
+
+@pytest.mark.asyncio
+async def test_css_selector(tester: Crawl4AiTester):
+ print("\n=== Testing CSS Selector ===")
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 7,
+ "css_selector": ".wide-tease-item__description",
+ "crawler_params": {"headless": True},
+ "extra": {"word_count_threshold": 10},
+ }
+
+ await tester.submit_and_validate(request)
+
+
+@pytest.mark.asyncio
+async def test_structured_extraction(tester: Crawl4AiTester):
+ schema = {
+ "name": "Coinbase Crypto Prices",
+ "baseSelector": "table > tbody > tr",
+ "fields": [
+ {
+ "name": "crypto",
+ "selector": "td:nth-child(1) h2",
+ "type": "text",
+ },
+ {
+ "name": "symbol",
+ "selector": "td:nth-child(1) p",
+ "type": "text",
+ },
+ {
+ "name": "price",
+ "selector": "td:nth-child(2)",
+ "type": "text",
+ },
+ ],
+ }
+ crawler_config: CrawlerRunConfig = CrawlerRunConfig(
+ extraction_strategy=JsonCssExtractionStrategy(schema),
+ )
+ request = {
+ "urls": ["https://www.coinbase.com/explore"],
+ "priority": 9,
+ "crawler_config": crawler_config.dump(),
+ }
+
+ result = await tester.submit_and_validate(request)
+
+ extracted = json.loads(result["extracted_content"])
+ print(f"Extracted {len(extracted)} items")
+ print("Sample item:", json.dumps(extracted[0], indent=2))
+
+ assert len(extracted) > 0
+
+
+@pytest.mark.asyncio
+async def test_llm_extraction(tester: Crawl4AiTester):
+ if not os.getenv("OPENAI_API_KEY"):
+ pytest.skip("OPENAI_API_KEY not set")
+
+ schema = {
+ "type": "object",
+ "properties": {
+ "model_name": {
+ "type": "string",
+ "description": "Name of the OpenAI model.",
+ },
+ "input_fee": {
+ "type": "string",
+ "description": "Fee for input token for the OpenAI model.",
+ },
+ "output_fee": {
+ "type": "string",
+ "description": "Fee for output token for the OpenAI model.",
+ },
+ },
+ "required": ["model_name", "input_fee", "output_fee"],
+ }
+
+ crawler_config: CrawlerRunConfig = CrawlerRunConfig(
+ extraction_strategy=LLMExtractionStrategy(
+ schema=schema,
+ llm_config=LLMConfig(
+ provider="openai/gpt-4o-mini",
+ api_token=os.getenv("OPENAI_API_KEY"),
+ ),
+ extraction_type="schema",
+ instruction="From the crawled content, extract all mentioned model names along with their fees for input and output tokens.",
+ ),
+ word_count_threshold=1,
+ )
+ request = {
+ "urls": ["https://openai.com/api/pricing"],
+ "priority": 8,
+ "crawler_config": crawler_config.dump(),
+ }
+
+ result = await tester.submit_and_validate(request)
+ extracted = json.loads(result["extracted_content"])
+ print(f"Extracted {len(extracted)} model pricing entries")
+ print("Sample entry:", json.dumps(extracted[0], indent=2))
+ assert result["success"]
+ assert len(extracted) > 0, "No model pricing entries found"
+
+
+@pytest.mark.asyncio
+@pytest.mark.timeout(30)
+async def test_llm_with_ollama(tester: Crawl4AiTester):
+ print("\n=== Testing LLM with Ollama ===")
+ schema = {
+ "type": "object",
+ "properties": {
+ "article_title": {
+ "type": "string",
+ "description": "The main title of the news article",
+ },
+ "summary": {
+ "type": "string",
+ "description": "A brief summary of the article content",
+ },
+ "main_topics": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "Main topics or themes discussed in the article",
+ },
+ },
+ }
+
+ crawler_config: CrawlerRunConfig = CrawlerRunConfig(
+ extraction_strategy=LLMExtractionStrategy(
+ schema=schema,
+ llm_config=LLMConfig(
+ provider="ollama/llama2",
+ ),
+ extraction_type="schema",
+ instruction="Extract the main article information including title, summary, and main topics.",
+ ),
+ word_count_threshold=1,
+ )
+
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 8,
+ "crawler_config": crawler_config.dump(),
+ }
+
+ result = await tester.submit_and_validate(request)
+ extracted = json.loads(result["extracted_content"])
+ print("Extracted content:", json.dumps(extracted, indent=2))
+ assert result["success"]
+ assert len(extracted) > 0, "No content extracted"
+
+
+@pytest.mark.asyncio
+async def test_cosine_extraction(tester: Crawl4AiTester):
+ print("\n=== Testing Cosine Extraction ===")
+ crawler_config: CrawlerRunConfig = CrawlerRunConfig(
+ extraction_strategy=CosineStrategy(
+ semantic_filter="business finance economy",
+ word_count_threshold=10,
+ max_dist=0.2,
+ top_k=3,
+ ),
+ word_count_threshold=1,
+ verbose=True,
+ )
+
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 8,
+ "crawler_config": crawler_config.dump(),
+ }
+
+ result = await tester.submit_and_validate(request)
+ extracted = json.loads(result["extracted_content"])
+ print(f"Extracted {len(extracted)} text clusters")
+ print("First cluster tags:", extracted[0]["tags"])
+ assert result["success"]
+ assert len(extracted) > 0, "No clusters found"
+
+
+@pytest.mark.asyncio
+async def test_screenshot(tester: Crawl4AiTester):
+ crawler_config: CrawlerRunConfig = CrawlerRunConfig(
+ screenshot=True,
+ word_count_threshold=1,
+ verbose=True,
+ )
+
+ browser_config: BrowserConfig = BrowserConfig(headless=True)
+
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 5,
+ "screenshot": True,
+ "crawler_config": crawler_config.dump(),
+ "browser_config": browser_config.dump(),
+ }
+
+ result = await tester.submit_and_validate(request)
+ assert result.get("screenshot")
+ image: ImageFile.ImageFile = Image.open(
+ io.BytesIO(base64.b64decode(result["screenshot"]))
+ )
+
+ assert image.format == "BMP"
+
+
+if __name__ == "__main__":
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/docker/test_crawl_task.py b/tests/docker/test_crawl_task.py
new file mode 100644
index 000000000..596f1d2aa
--- /dev/null
+++ b/tests/docker/test_crawl_task.py
@@ -0,0 +1,275 @@
+import asyncio
+import json
+import os
+import sys
+import time
+from typing import Any, Dict
+
+import pytest
+from httpx import AsyncClient, Response, codes
+
+from crawl4ai.async_configs import CrawlerRunConfig
+
+from .common import async_client
+
+
+class NBCNewsAPITest:
+ def __init__(self):
+ self.client: AsyncClient = async_client()
+
+ async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
+ if "crawler_config" not in request_data:
+ request_data["crawler_config"] = CrawlerRunConfig(stream=True).dump()
+
+ response: Response = await self.client.post("/crawl/stream", json=request_data)
+ assert response.status_code == codes.OK, f"Error: {response.text}"
+ result = response.json()
+ assert "task_id" in result
+
+ return result["task_id"]
+
+ async def get_task_status(self, task_id: str) -> Dict[str, Any]:
+ response: Response = await self.client.get(f"/task/{task_id}")
+ assert response.status_code == codes.OK
+ result = response.json()
+ assert "status" in result
+ return result
+
+ async def wait_for_task(
+ self, task_id: str, timeout: int = 300, poll_interval: int = 2
+ ) -> Dict[str, Any]:
+ start_time = time.time()
+ while True:
+ if time.time() - start_time > timeout:
+ raise TimeoutError(
+ f"Task {task_id} did not complete within {timeout} seconds"
+ )
+
+ status = await self.get_task_status(task_id)
+ if status["status"] in ["completed", "failed"]:
+ return status
+
+ await asyncio.sleep(poll_interval)
+
+ async def check_health(self) -> Dict[str, Any]:
+ response: Response = await self.client.get("/health")
+ assert response.status_code == codes.OK
+ return response.json()
+
+
+@pytest.fixture
+def api() -> NBCNewsAPITest:
+ return NBCNewsAPITest()
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Crawl with task_id not implemented yet")
+async def test_basic_crawl(api: NBCNewsAPITest):
+ print("\n=== Testing Basic Crawl ===")
+ request = {"urls": ["https://www.nbcnews.com/business"], "priority": 10}
+ task_id = await api.submit_crawl(request)
+ result = await api.wait_for_task(task_id)
+ print(f"Basic crawl result length: {len(result['result']['markdown'])}")
+ assert result["status"] == "completed"
+ assert "result" in result
+ assert result["result"]["success"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Crawl with task_id not implemented yet")
+async def test_js_execution(api: NBCNewsAPITest):
+ print("\n=== Testing JS Execution ===")
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 8,
+ "js_code": [
+ "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
+ ],
+ "wait_for": "article.tease-card:nth-child(10)",
+ "crawler_params": {"headless": True},
+ }
+ task_id = await api.submit_crawl(request)
+ result = await api.wait_for_task(task_id)
+ print(f"JS execution result length: {len(result['result']['markdown'])}")
+ assert result["status"] == "completed"
+ assert result["result"]["success"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Crawl with task_id not implemented yet")
+async def test_css_selector(api: NBCNewsAPITest):
+ print("\n=== Testing CSS Selector ===")
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 7,
+ "css_selector": ".wide-tease-item__description",
+ }
+ task_id = await api.submit_crawl(request)
+ result = await api.wait_for_task(task_id)
+ print(f"CSS selector result length: {len(result['result']['markdown'])}")
+ assert result["status"] == "completed"
+ assert result["result"]["success"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Crawl with task_id not implemented yet")
+async def test_structured_extraction(api: NBCNewsAPITest):
+ print("\n=== Testing Structured Extraction ===")
+ schema = {
+ "name": "NBC News Articles",
+ "baseSelector": "article.tease-card",
+ "fields": [
+ {"name": "title", "selector": "h2", "type": "text"},
+ {
+ "name": "description",
+ "selector": ".tease-card__description",
+ "type": "text",
+ },
+ {
+ "name": "link",
+ "selector": "a",
+ "type": "attribute",
+ "attribute": "href",
+ },
+ ],
+ }
+
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 9,
+ "extraction_config": {"type": "json_css", "params": {"schema": schema}},
+ }
+ task_id = await api.submit_crawl(request)
+ result = await api.wait_for_task(task_id)
+ extracted = json.loads(result["result"]["extracted_content"])
+ print(f"Extracted {len(extracted)} articles")
+ assert result["status"] == "completed"
+ assert result["result"]["success"]
+ assert len(extracted) > 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Crawl with task_id not implemented yet")
+async def test_batch_crawl(api: NBCNewsAPITest):
+ print("\n=== Testing Batch Crawl ===")
+ request = {
+ "urls": [
+ "https://www.nbcnews.com/business",
+ "https://www.nbcnews.com/business/consumer",
+ "https://www.nbcnews.com/business/economy",
+ ],
+ "priority": 6,
+ "crawler_params": {"headless": True},
+ }
+ task_id = await api.submit_crawl(request)
+ result = await api.wait_for_task(task_id)
+ print(f"Batch crawl completed, got {len(result['results'])} results")
+ assert result["status"] == "completed"
+ assert "results" in result
+ assert len(result["results"]) == 3
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Crawl with task_id not implemented yet")
+async def test_llm_extraction(api: NBCNewsAPITest):
+ print("\n=== Testing LLM Extraction with Ollama ===")
+ schema = {
+ "type": "object",
+ "properties": {
+ "article_title": {
+ "type": "string",
+ "description": "The main title of the news article",
+ },
+ "summary": {
+ "type": "string",
+ "description": "A brief summary of the article content",
+ },
+ "main_topics": {
+ "type": "array",
+ "items": {"type": "string"},
+ "description": "Main topics or themes discussed in the article",
+ },
+ },
+ "required": ["article_title", "summary", "main_topics"],
+ }
+
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 8,
+ "extraction_config": {
+ "type": "llm",
+ "params": {
+ "provider": "openai/gpt-4o-mini",
+ "api_key": os.getenv("OLLAMA_API_KEY"),
+ "schema": schema,
+ "extraction_type": "schema",
+ "instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
+ Focus on the primary business news article on the page.""",
+ },
+ },
+ "crawler_params": {"headless": True, "word_count_threshold": 1},
+ }
+
+ task_id = await api.submit_crawl(request)
+ result = await api.wait_for_task(task_id)
+
+ if result["status"] == "completed":
+ extracted = json.loads(result["result"]["extracted_content"])
+ print("Extracted article analysis:")
+ print(json.dumps(extracted, indent=2))
+
+ assert result["status"] == "completed"
+ assert result["result"]["success"]
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Crawl with task_id not implemented yet")
+async def test_screenshot(api: NBCNewsAPITest):
+ print("\n=== Testing Screenshot ===")
+ request = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 5,
+ "screenshot": True,
+ "crawler_params": {"headless": True},
+ }
+ task_id = await api.submit_crawl(request)
+ result = await api.wait_for_task(task_id)
+ print("Screenshot captured:", bool(result["result"]["screenshot"]))
+ assert result["status"] == "completed"
+ assert result["result"]["success"]
+ assert result["result"]["screenshot"] is not None
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Crawl with task_id not implemented yet")
+async def test_priority_handling(api: NBCNewsAPITest):
+ print("\n=== Testing Priority Handling ===")
+ # Submit low priority task first
+ low_priority = {
+ "urls": ["https://www.nbcnews.com/business"],
+ "priority": 1,
+ "crawler_params": {"headless": True},
+ }
+ low_task_id = await api.submit_crawl(low_priority)
+
+ # Submit high priority task
+ high_priority = {
+ "urls": ["https://www.nbcnews.com/business/consumer"],
+ "priority": 10,
+ "crawler_params": {"headless": True},
+ }
+ high_task_id = await api.submit_crawl(high_priority)
+
+ # Get both results
+ high_result = await api.wait_for_task(high_task_id)
+ low_result = await api.wait_for_task(low_task_id)
+
+ print("Both tasks completed")
+ assert high_result["status"] == "completed"
+ assert low_result["status"] == "completed"
+
+
+if __name__ == "__main__":
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/docker/test_docker.py b/tests/docker/test_docker.py
index cf95671ea..f64355711 100644
--- a/tests/docker/test_docker.py
+++ b/tests/docker/test_docker.py
@@ -1,58 +1,25 @@
-import requests
-import time
-import httpx
-import asyncio
-from typing import Dict, Any
+import sys
+
+from httpx import codes
+import pytest
+
from crawl4ai import (
BrowserConfig, CrawlerRunConfig, DefaultMarkdownGenerator,
PruningContentFilter, JsonCssExtractionStrategy, LLMContentFilter, CacheMode
)
-from crawl4ai import LLMConfig
-from crawl4ai.docker_client import Crawl4aiDockerClient
-
-class Crawl4AiTester:
- def __init__(self, base_url: str = "http://localhost:11235"):
- self.base_url = base_url
-
- def submit_and_wait(
- self, request_data: Dict[str, Any], timeout: int = 300
- ) -> Dict[str, Any]:
- # Submit crawl job
- response = requests.post(f"{self.base_url}/crawl", json=request_data)
- task_id = response.json()["task_id"]
- print(f"Task ID: {task_id}")
-
- # Poll for result
- start_time = time.time()
- while True:
- if time.time() - start_time > timeout:
- raise TimeoutError(
- f"Task {task_id} did not complete within {timeout} seconds"
- )
+from crawl4ai.async_configs import LLMConfig
- result = requests.get(f"{self.base_url}/task/{task_id}")
- status = result.json()
+from .common import async_client, docker_client
- if status["status"] == "failed":
- print("Task failed:", status.get("error"))
- raise Exception(f"Task failed: {status.get('error')}")
- if status["status"] == "completed":
- return status
+@pytest.fixture
+def browser_config() -> BrowserConfig:
+ return BrowserConfig(headless=True, viewport_width=1200, viewport_height=800)
- time.sleep(2)
-async def test_direct_api():
- """Test direct API endpoints without using the client SDK"""
- print("\n=== Testing Direct API Calls ===")
-
- # Test 1: Basic crawl with content filtering
- browser_config = BrowserConfig(
- headless=True,
- viewport_width=1200,
- viewport_height=800
- )
-
+@pytest.mark.asyncio
+async def test_direct_filtering(browser_config: BrowserConfig):
+ """Direct request with content filtering."""
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(
@@ -72,22 +39,25 @@ async def test_direct_api():
}
# Make direct API call
- async with httpx.AsyncClient() as client:
+ async with async_client() as client:
response = await client.post(
- "http://localhost:8000/crawl",
+ "/crawl",
json=request_data,
timeout=300
)
- assert response.status_code == 200
+ assert response.status_code == codes.OK
result = response.json()
- print("Basic crawl result:", result["success"])
+ assert result["success"]
+
- # Test 2: Structured extraction with JSON CSS
+@pytest.mark.asyncio
+async def test_direct_structured_extraction(browser_config: BrowserConfig):
+ """Direct request using structured extraction with JSON CSS."""
schema = {
- "baseSelector": "article.post",
+ "baseSelector": "body > div",
"fields": [
{"name": "title", "selector": "h1", "type": "text"},
- {"name": "content", "selector": ".content", "type": "html"}
+ {"name": "content", "selector": "p", "type": "html"}
]
}
@@ -96,30 +66,50 @@ async def test_direct_api():
extraction_strategy=JsonCssExtractionStrategy(schema=schema)
)
- request_data["crawler_config"] = crawler_config.dump()
+ request_data = {
+ "urls": ["https://example.com"],
+ "browser_config": browser_config.dump(),
+ "crawler_config": crawler_config.dump()
+ }
- async with httpx.AsyncClient() as client:
+ async with async_client() as client:
response = await client.post(
- "http://localhost:8000/crawl",
+ "/crawl",
json=request_data
)
- assert response.status_code == 200
+ assert response.status_code == codes.OK
result = response.json()
- print("Structured extraction result:", result["success"])
+ assert result["success"]
+ assert result["results"]
+ assert len(result["results"]) == 1
+ assert "extracted_content" in result["results"][0]
+ assert (
+ result["results"][0]["extracted_content"]
+ == """[
+ {
+ "title": "Example Domain",
+            "content": "<p>This domain is for use in illustrative examples in documents. You may use this\\n    domain in literature without prior coordination or asking for permission.</p>"
+ }
+]"""
+ )
+
+
+@pytest.mark.asyncio
+async def test_direct_schema(browser_config: BrowserConfig):
+ """Get the schema."""
+ async with async_client() as client:
+ response = await client.get("/schema")
+ assert response.status_code == codes.OK
+ schemas = response.json()
+ assert schemas
+ assert len(schemas.keys()) == 2
+ print("Retrieved schemas for:", list(schemas.keys()))
- # Test 3: Get schema
- # async with httpx.AsyncClient() as client:
- # response = await client.get("http://localhost:8000/schema")
- # assert response.status_code == 200
- # schemas = response.json()
- # print("Retrieved schemas for:", list(schemas.keys()))
-async def test_with_client():
+@pytest.mark.asyncio
+async def test_with_client_basic():
"""Test using the Crawl4AI Docker client SDK"""
- print("\n=== Testing Client SDK ===")
-
- async with Crawl4aiDockerClient(verbose=True) as client:
- # Test 1: Basic crawl
+ async with docker_client() as client:
browser_config = BrowserConfig(headless=True)
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
@@ -128,17 +118,23 @@ async def test_with_client():
threshold=0.48,
threshold_type="fixed"
)
- )
+ ),
+ stream=False,
)
+ await client.authenticate("test@example.com")
result = await client.crawl(
urls=["https://example.com"],
browser_config=browser_config,
crawler_config=crawler_config
)
- print("Client SDK basic crawl:", result.success)
+ assert result.success
+
- # Test 2: LLM extraction with streaming
+@pytest.mark.asyncio
+async def test_with_client_llm_streaming():
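+    """Streaming crawl through the Docker client SDK."""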
+ async with docker_client() as client:
+ browser_config = BrowserConfig(headless=True)
crawler_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
markdown_generator=DefaultMarkdownGenerator(
@@ -150,26 +146,24 @@ async def test_with_client():
stream=True
)
+ await client.authenticate("test@example.com")
async for result in await client.crawl(
urls=["https://example.com"],
browser_config=browser_config,
crawler_config=crawler_config
):
- print(f"Streaming result for: {result.url}")
+ assert result.success, f"Stream failed with: {result.error_message}"
- # # Test 3: Get schema
- # schemas = await client.get_schema()
- # print("Retrieved client schemas for:", list(schemas.keys()))
-async def main():
- """Run all tests"""
- # Test direct API
- print("Testing direct API calls...")
- await test_direct_api()
+@pytest.mark.asyncio
+async def test_with_client_get_schema():
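+    """Retrieve the config schema through the Docker client SDK."""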
+ async with docker_client() as client:
+ await client.authenticate("test@example.com")
+ schemas = await client.get_schema()
+ print("Retrieved client schemas for:", list(schemas.keys()))
- # Test client SDK
- print("\nTesting client SDK...")
- await test_with_client()
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/docker/test_dockerclient.py b/tests/docker/test_dockerclient.py
index cba6c4c9c..9b7eaa691 100644
--- a/tests/docker/test_dockerclient.py
+++ b/tests/docker/test_dockerclient.py
@@ -1,34 +1,50 @@
-import asyncio
-from crawl4ai.docker_client import Crawl4aiDockerClient
-from crawl4ai import (
- BrowserConfig,
- CrawlerRunConfig
-)
-
-async def main():
- async with Crawl4aiDockerClient(base_url="http://localhost:8000", verbose=True) as client:
+import sys
+
+import pytest
+
+from crawl4ai import BrowserConfig, CrawlerRunConfig
+
+from .common import docker_client
+
+
+@pytest.mark.asyncio
+async def test_non_streaming():
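+    """Non-streaming crawl of multiple URLs via the Docker client."""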
+ async with docker_client() as client:
await client.authenticate("test@example.com")
-
+
# Non-streaming crawl
results = await client.crawl(
["https://example.com", "https://python.org"],
browser_config=BrowserConfig(headless=True),
crawler_config=CrawlerRunConfig()
)
- print(f"Non-streaming results: {results}")
-
+ assert results
+ for result in results:
+ assert result.success
+ print(f"Non-streamed result: {result}")
+
+
+@pytest.mark.asyncio
+async def test_streaming():
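+    """Streaming crawl of multiple URLs via the Docker client, then a schema fetch."""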
+ async with docker_client() as client:
# Streaming crawl
crawler_config = CrawlerRunConfig(stream=True)
+ await client.authenticate("user@example.com")
async for result in await client.crawl(
["https://example.com", "https://python.org"],
browser_config=BrowserConfig(headless=True),
crawler_config=crawler_config
):
+ assert result.success
print(f"Streamed result: {result}")
-
+
# Get schema
schema = await client.get_schema()
+ assert schema
print(f"Schema: {schema}")
+
if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/docker/test_serialization.py b/tests/docker/test_serialization.py
index 6ce800059..93faa3e98 100644
--- a/tests/docker/test_serialization.py
+++ b/tests/docker/test_serialization.py
@@ -1,110 +1,24 @@
-import inspect
-from typing import Any, Dict
-from enum import Enum
-
-from crawl4ai import LLMConfig
-
-def to_serializable_dict(obj: Any) -> Dict:
- """
- Recursively convert an object to a serializable dictionary using {type, params} structure
- for complex objects.
- """
- if obj is None:
- return None
-
- # Handle basic types
- if isinstance(obj, (str, int, float, bool)):
- return obj
-
- # Handle Enum
- if isinstance(obj, Enum):
- return {
- "type": obj.__class__.__name__,
- "params": obj.value
- }
-
- # Handle datetime objects
- if hasattr(obj, 'isoformat'):
- return obj.isoformat()
-
- # Handle lists, tuples, and sets
- if isinstance(obj, (list, tuple, set)):
- return [to_serializable_dict(item) for item in obj]
-
- # Handle dictionaries - preserve them as-is
- if isinstance(obj, dict):
- return {
- "type": "dict", # Mark as plain dictionary
- "value": {str(k): to_serializable_dict(v) for k, v in obj.items()}
- }
-
- # Handle class instances
- if hasattr(obj, '__class__'):
- # Get constructor signature
- sig = inspect.signature(obj.__class__.__init__)
- params = sig.parameters
-
- # Get current values
- current_values = {}
- for name, param in params.items():
- if name == 'self':
- continue
-
- value = getattr(obj, name, param.default)
-
- # Only include if different from default, considering empty values
- if not (is_empty_value(value) and is_empty_value(param.default)):
- if value != param.default:
- current_values[name] = to_serializable_dict(value)
-
- return {
- "type": obj.__class__.__name__,
- "params": current_values
- }
-
- return str(obj)
-
-def from_serializable_dict(data: Any) -> Any:
- """
- Recursively convert a serializable dictionary back to an object instance.
- """
- if data is None:
- return None
-
- # Handle basic types
- if isinstance(data, (str, int, float, bool)):
- return data
-
- # Handle typed data
- if isinstance(data, dict) and "type" in data:
- # Handle plain dictionaries
- if data["type"] == "dict":
- return {k: from_serializable_dict(v) for k, v in data["value"].items()}
-
- # Import from crawl4ai for class instances
- import crawl4ai
- cls = getattr(crawl4ai, data["type"])
-
- # Handle Enum
- if issubclass(cls, Enum):
- return cls(data["params"])
-
- # Handle class instances
- constructor_args = {
- k: from_serializable_dict(v) for k, v in data["params"].items()
- }
- return cls(**constructor_args)
-
- # Handle lists
- if isinstance(data, list):
- return [from_serializable_dict(item) for item in data]
-
- # Handle raw dictionaries (legacy support)
- if isinstance(data, dict):
- return {k: from_serializable_dict(v) for k, v in data.items()}
-
- return data
-
+import sys
+from typing import Any, List
+
+import pytest
+from _pytest.mark import ParameterSet
+
+from crawl4ai import (
+ BM25ContentFilter,
+ CacheMode,
+ CrawlerRunConfig,
+ DefaultMarkdownGenerator,
+ JsonCssExtractionStrategy,
+ LLMContentFilter,
+ LXMLWebScrapingStrategy,
+ PruningContentFilter,
+ RegexChunking,
+ WebScrapingStrategy,
+)
+from crawl4ai.async_configs import LLMConfig, from_serializable, to_serializable_dict
+
+
def is_empty_value(value: Any) -> bool:
"""Check if a value is effectively empty/null."""
if value is None:
@@ -113,90 +27,27 @@ def is_empty_value(value: Any) -> bool:
return True
return False
-# if __name__ == "__main__":
-# from crawl4ai import (
-# CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator,
-# PruningContentFilter, BM25ContentFilter, LLMContentFilter,
-# JsonCssExtractionStrategy, CosineStrategy, RegexChunking,
-# WebScrapingStrategy, LXMLWebScrapingStrategy
-# )
-
-# # Test Case 1: BM25 content filtering through markdown generator
-# config1 = CrawlerRunConfig(
-# cache_mode=CacheMode.BYPASS,
-# markdown_generator=DefaultMarkdownGenerator(
-# content_filter=BM25ContentFilter(
-# user_query="technology articles",
-# bm25_threshold=1.2,
-# language="english"
-# )
-# ),
-# chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]),
-# excluded_tags=["nav", "footer", "aside"],
-# remove_overlay_elements=True
-# )
-
-# # Serialize
-# serialized = to_serializable_dict(config1)
-# print("\nSerialized Config:")
-# print(serialized)
-
-# # Example output structure would now look like:
-# """
-# {
-# "type": "CrawlerRunConfig",
-# "params": {
-# "cache_mode": {
-# "type": "CacheMode",
-# "params": "bypass"
-# },
-# "markdown_generator": {
-# "type": "DefaultMarkdownGenerator",
-# "params": {
-# "content_filter": {
-# "type": "BM25ContentFilter",
-# "params": {
-# "user_query": "technology articles",
-# "bm25_threshold": 1.2,
-# "language": "english"
-# }
-# }
-# }
-# }
-# }
-# }
-# """
-
-# # Deserialize
-# deserialized = from_serializable_dict(serialized)
-# print("\nDeserialized Config:")
-# print(to_serializable_dict(deserialized))
-
-# # Verify they match
-# assert to_serializable_dict(config1) == to_serializable_dict(deserialized)
-# print("\nVerification passed: Configuration matches after serialization/deserialization!")
-
-if __name__ == "__main__":
- from crawl4ai import (
- CrawlerRunConfig, CacheMode, DefaultMarkdownGenerator,
- PruningContentFilter, BM25ContentFilter, LLMContentFilter,
- JsonCssExtractionStrategy, RegexChunking,
- WebScrapingStrategy, LXMLWebScrapingStrategy
- )
+def config_params() -> List[ParameterSet]:
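+    """Build the CrawlerRunConfig parameter sets used by the serialization round-trip test."""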
# Test Case 1: BM25 content filtering through markdown generator
- config1 = CrawlerRunConfig(
- cache_mode=CacheMode.BYPASS,
- markdown_generator=DefaultMarkdownGenerator(
- content_filter=BM25ContentFilter(
- user_query="technology articles",
- bm25_threshold=1.2,
- language="english"
- )
- ),
- chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]),
- excluded_tags=["nav", "footer", "aside"],
- remove_overlay_elements=True
+ params: List[ParameterSet] = []
+ params.append(
+ pytest.param(
+ CrawlerRunConfig(
+ cache_mode=CacheMode.BYPASS,
+ markdown_generator=DefaultMarkdownGenerator(
+ content_filter=BM25ContentFilter(
+ user_query="technology articles",
+ bm25_threshold=1.2,
+ language="english",
+ )
+ ),
+ chunking_strategy=RegexChunking(patterns=[r"\n\n", r"\.\s+"]),
+ excluded_tags=["nav", "footer", "aside"],
+ remove_overlay_elements=True,
+ ),
+ id="BM25 Content Filter",
+ )
)
# Test Case 2: LLM-based extraction with pruning filter
@@ -204,52 +55,64 @@ def is_empty_value(value: Any) -> bool:
"baseSelector": "article.post",
"fields": [
{"name": "title", "selector": "h1", "type": "text"},
- {"name": "content", "selector": ".content", "type": "html"}
- ]
+ {"name": "content", "selector": ".content", "type": "html"},
+ ],
}
- config2 = CrawlerRunConfig(
- extraction_strategy=JsonCssExtractionStrategy(schema=schema),
- markdown_generator=DefaultMarkdownGenerator(
- content_filter=PruningContentFilter(
- threshold=0.48,
- threshold_type="fixed",
- min_word_threshold=0
+ params.append(
+ pytest.param(
+ CrawlerRunConfig(
+ extraction_strategy=JsonCssExtractionStrategy(schema=schema),
+ markdown_generator=DefaultMarkdownGenerator(
+ content_filter=PruningContentFilter(
+ threshold=0.48, threshold_type="fixed", min_word_threshold=0
+ ),
+ options={"ignore_links": True},
+ ),
+ scraping_strategy=LXMLWebScrapingStrategy(),
),
- options={"ignore_links": True}
- ),
- scraping_strategy=LXMLWebScrapingStrategy()
+ id="LLM Pruning Filter",
+ )
)
# Test Case 3:LLM content filter
- config3 = CrawlerRunConfig(
- markdown_generator=DefaultMarkdownGenerator(
- content_filter=LLMContentFilter(
- llm_config = LLMConfig(provider="openai/gpt-4"),
- instruction="Extract key technical concepts",
- chunk_token_threshold=2000,
- overlap_rate=0.1
+ params.append(
+ pytest.param(
+ CrawlerRunConfig(
+ markdown_generator=DefaultMarkdownGenerator(
+ content_filter=LLMContentFilter(
+ llm_config=LLMConfig(provider="openai/gpt-4"),
+ instruction="Extract key technical concepts",
+ chunk_token_threshold=2000,
+ overlap_rate=0.1,
+ ),
+ options={"ignore_images": True},
+ ),
+ scraping_strategy=WebScrapingStrategy(),
),
- options={"ignore_images": True}
- ),
- scraping_strategy=WebScrapingStrategy()
+ id="LLM Content Filter",
+ )
)
- # Test all configurations
- test_configs = [config1, config2, config3]
-
- for i, config in enumerate(test_configs, 1):
- print(f"\nTesting Configuration {i}:")
-
- # Serialize
- serialized = to_serializable_dict(config)
- print(f"\nSerialized Config {i}:")
- print(serialized)
-
- # Deserialize
- deserialized = from_serializable_dict(serialized)
- print(f"\nDeserialized Config {i}:")
- print(to_serializable_dict(deserialized)) # Convert back to dict for comparison
-
- # Verify they match
- assert to_serializable_dict(config) == to_serializable_dict(deserialized)
- print(f"\nVerification passed: Configuration {i} matches after serialization/deserialization!")
\ No newline at end of file
+ return params
+
+
+@pytest.mark.parametrize("config", config_params())
+def test_serialization(config: CrawlerRunConfig) -> None:
+ # Serialize
+ serialized = to_serializable_dict(config)
+ print("\nSerialized Config:")
+ print(serialized)
+
+ # Deserialize
+ deserialized = from_serializable(serialized)
+ print("\nDeserialized Config:")
+ print(to_serializable_dict(deserialized)) # Convert back to dict for comparison
+
+ # Verify they match
+ assert to_serializable_dict(config) == to_serializable_dict(deserialized)
+
+
+if __name__ == "__main__":
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/docker/test_server.py b/tests/docker/test_server.py
index 7bb0195bc..59dfd8130 100644
--- a/tests/docker/test_server.py
+++ b/tests/docker/test_server.py
@@ -1,146 +1,136 @@
import asyncio
import json
-from typing import Optional
+import sys
+from typing import Optional, Union
from urllib.parse import quote
-async def test_endpoint(
- endpoint: str,
- url: str,
- params: Optional[dict] = None,
- expected_status: int = 200
-) -> None:
+import aiohttp
+import pytest
+from httpx import Response, codes
+
+from .common import TEST_URLS, async_client, markdown_params
+
+EndpointResponse = Optional[Union[dict, str]]
+
+
+async def endpoint(
+ endpoint: str, url: str, params: Optional[dict] = None, expected_status: int = codes.OK
+) -> EndpointResponse:
"""Test an endpoint and print results"""
- import aiohttp
-
params = params or {}
param_str = "&".join(f"{k}={v}" for k, v in params.items())
- full_url = f"http://localhost:8000/{endpoint}/{quote(url)}"
+ path = f"/{endpoint}/{quote(url)}"
if param_str:
- full_url += f"?{param_str}"
-
- print(f"\nTesting: {full_url}")
-
- try:
- async with aiohttp.ClientSession() as session:
- async with session.get(full_url) as response:
- status = response.status
- try:
- data = await response.json()
- except:
- data = await response.text()
-
- print(f"Status: {status} (Expected: {expected_status})")
- if isinstance(data, dict):
- print(f"Response: {json.dumps(data, indent=2)}")
- else:
- print(f"Response: {data[:500]}...") # First 500 chars
- assert status == expected_status
- return data
- except Exception as e:
- print(f"Error: {str(e)}")
- return None
-
-async def test_llm_task_completion(task_id: str) -> None:
+ path += f"?{param_str}"
+
+ print(f"\nTesting: {path}")
+
+ async with async_client() as session:
+ response: Response = await session.get(path)
+ content_type: str = response.headers.get(
+ aiohttp.hdrs.CONTENT_TYPE, ""
+ ).lower()
+ data: Union[dict, str] = (
+ response.json() if content_type == "application/json" else response.text
+ )
+
+ print(f"Status: {response.status_code} (Expected: {expected_status})")
+ if isinstance(data, dict):
+ print(f"Response: {json.dumps(data, indent=2)}")
+ else:
+ print(f"Response: {data[:500]}...") # First 500 chars
+ assert response.status_code == expected_status
+ return data
+
+
+async def llm_task_completion(task_id: str) -> Optional[dict]:
"""Poll task until completion"""
for _ in range(10): # Try 10 times
- result = await test_endpoint("llm", task_id)
+ result: EndpointResponse = await endpoint("llm", task_id)
+ assert result, "Failed to process endpoint request"
+ assert isinstance(result, dict), "Expected dict response"
+
if result and result.get("status") in ["completed", "failed"]:
return result
print("Task still processing, waiting 5 seconds...")
await asyncio.sleep(5)
print("Task timed out")
+ return None
-async def run_tests():
- print("Starting API Tests...")
-
- # Test URLs
- urls = [
- "example.com",
- "https://www.python.org",
- "https://news.ycombinator.com/news",
- "https://github.com/trending"
- ]
-
- print("\n=== Testing Markdown Endpoint ===")
- for url in[] : #urls:
- # Test different filter types
- for filter_type in ["raw", "fit", "bm25", "llm"]:
- params = {"f": filter_type}
- if filter_type in ["bm25", "llm"]:
- params["q"] = "extract main content"
-
- # Test with and without cache
- for cache in ["0", "1"]:
- params["c"] = cache
- await test_endpoint("md", url, params)
- await asyncio.sleep(1) # Be nice to the server
-
- print("\n=== Testing LLM Endpoint ===")
- for url in []: # urls:
- # Test basic extraction
- result = await test_endpoint(
- "llm",
- url,
- {"q": "Extract title and main content"}
- )
- if result and "task_id" in result:
- print("\nChecking task completion...")
- await test_llm_task_completion(result["task_id"])
-
- # Test with schema
- schema = {
- "type": "object",
- "properties": {
- "title": {"type": "string"},
- "content": {"type": "string"},
- "links": {"type": "array", "items": {"type": "string"}}
- }
- }
- result = await test_endpoint(
- "llm",
- url,
- {
- "q": "Extract content with links",
- "s": json.dumps(schema),
- "c": "1" # Test with cache
- }
- )
- if result and "task_id" in result:
- print("\nChecking schema task completion...")
- await test_llm_task_completion(result["task_id"])
-
- await asyncio.sleep(2) # Be nice to the server
-
- print("\n=== Testing Error Cases ===")
- # Test invalid URL
- await test_endpoint(
- "md",
- "not_a_real_url",
- expected_status=500
- )
-
- # Test invalid filter type
- await test_endpoint(
- "md",
- "example.com",
- {"f": "invalid"},
- expected_status=422
- )
-
- # Test LLM without query
- await test_endpoint(
- "llm",
- "example.com"
+
+@pytest.mark.asyncio
+@pytest.mark.timeout(60) # LLM tasks can take a while.
+@pytest.mark.parametrize("url,params", markdown_params())
+async def test_markdown_endpoint(url: str, params: dict[str, str]):
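+    """Fetch markdown for a test URL with the given filter parameters."""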
+ response: EndpointResponse = await endpoint("md", url, params)
+ assert response, "Failed to process endpoint request"
+ assert isinstance(response, str), "Expected str response"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("url", TEST_URLS)
+@pytest.mark.skip("LLM endpoint doesn't task based requests yet")
+async def test_llm_endpoint_no_schema(url: str):
+ result: EndpointResponse = await endpoint(
+ "llm", url, {"q": "Extract title and main content"}
)
-
- # Test invalid task ID
- await test_endpoint(
- "llm",
- "llm_invalid_task",
- expected_status=404
+ assert result, "Failed to process endpoint request"
+ assert isinstance(result, dict), "Expected dict response"
+ assert "task_id" in result
+
+ print("\nChecking task completion...")
+ await llm_task_completion(result["task_id"])
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("url", TEST_URLS)
+@pytest.mark.skip("LLM endpoint doesn't task based or schema requests yet")
+async def test_llm_endpoint_schema(url: str):
+ schema = {
+ "type": "object",
+ "properties": {
+ "title": {"type": "string"},
+ "content": {"type": "string"},
+ "links": {"type": "array", "items": {"type": "string"}},
+ },
+ }
+ result: EndpointResponse = await endpoint(
+ "llm",
+ url,
+ {
+ "q": "Extract content with links",
+ "s": json.dumps(schema),
+ "c": "1", # Test with cache
+ },
)
-
- print("\nAll tests completed!")
+ assert result, "Failed to process endpoint request"
+ assert isinstance(result, dict), "Expected dict response"
+ assert "task_id" in result
+ print("\nChecking schema task completion...")
+ await llm_task_completion(result["task_id"])
+
+
+@pytest.mark.asyncio
+async def test_invalid_url():
+ await endpoint("md", "not_a_real_url", expected_status=codes.INTERNAL_SERVER_ERROR)
+
+
+@pytest.mark.asyncio
+async def test_invalid_filter():
+ await endpoint("md", "example.com", {"f": "invalid"}, expected_status=codes.UNPROCESSABLE_ENTITY)
+
+
+@pytest.mark.asyncio
+async def test_llm_without_query():
+ await endpoint("llm", "example.com", expected_status=codes.BAD_REQUEST)
+
+
+@pytest.mark.asyncio
+async def test_invalid_task():
+ await endpoint("llm", "llm_invalid_task", expected_status=codes.BAD_REQUEST)
+
if __name__ == "__main__":
- asyncio.run(run_tests())
\ No newline at end of file
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/docker/test_server_token.py b/tests/docker/test_server_token.py
index 220b6ca2c..0892866ac 100644
--- a/tests/docker/test_server_token.py
+++ b/tests/docker/test_server_token.py
@@ -1,212 +1,208 @@
-import asyncio
import json
-from typing import Optional
+import sys
+from typing import AsyncGenerator, Optional, Union
from urllib.parse import quote
-async def get_token(session, email: str = "test@example.com") -> str:
+import aiohttp
+import pytest
+import pytest_asyncio
+from httpx import AsyncClient, Response, codes
+
+from tests.docker.common import TEST_URLS, async_client, markdown_params
+
+
+async def get_token(client: AsyncClient, email: str = "test@example.com") -> str:
"""Fetch a JWT token from the /token endpoint."""
- url = "http://localhost:8000/token"
- payload = {"email": email}
- print(f"\nFetching token from {url} with email: {email}")
- try:
- async with session.post(url, json=payload) as response:
- status = response.status
- data = await response.json()
- print(f"Token Response Status: {status}")
- print(f"Token Response: {json.dumps(data, indent=2)}")
- if status == 200:
- return data["access_token"]
- else:
- raise Exception(f"Failed to get token: {data.get('detail', 'Unknown error')}")
- except Exception as e:
- print(f"Error fetching token: {str(e)}")
- raise
-
-async def test_endpoint(
- session,
- endpoint: str,
- url: str,
- token: str,
- params: Optional[dict] = None,
- expected_status: int = 200
-) -> Optional[dict]:
- """Test an endpoint with token and print results."""
- params = params or {}
- param_str = "&".join(f"{k}={v}" for k, v in params.items())
- full_url = f"http://localhost:8000/{endpoint}/{quote(url)}"
- if param_str:
- full_url += f"?{param_str}"
-
- headers = {"Authorization": f"Bearer {token}"}
- print(f"\nTesting: {full_url}")
-
- try:
- async with session.get(full_url, headers=headers) as response:
- status = response.status
- try:
- data = await response.json()
- except:
- data = await response.text()
-
- print(f"Status: {status} (Expected: {expected_status})")
- if isinstance(data, dict):
- print(f"Response: {json.dumps(data, indent=2)}")
- else:
- print(f"Response: {data[:500]}...") # First 500 chars
- assert status == expected_status, f"Expected {expected_status}, got {status}"
- return data
- except Exception as e:
- print(f"Error: {str(e)}")
- return None
-
-
-async def test_stream_crawl(session, token: str):
- """Test the /crawl/stream endpoint with multiple URLs."""
- url = "http://localhost:8000/crawl/stream"
- payload = {
- "urls": [
- "https://example.com",
- "https://example.com/page1", # Replicated example.com with variation
- "https://example.com/page2", # Replicated example.com with variation
- "https://example.com/page3", # Replicated example.com with variation
- # "https://www.python.org",
- # "https://news.ycombinator.com/news"
- ],
- "browser_config": {"headless": True, "viewport": {"width": 1200}},
- "crawler_config": {"stream": True, "cache_mode": "bypass"}
- }
- headers = {"Authorization": f"Bearer {token}"}
- print(f"\nTesting Streaming Crawl: {url}")
- print(f"Payload: {json.dumps(payload, indent=2)}")
-
- try:
- async with session.post(url, json=payload, headers=headers) as response:
- status = response.status
- print(f"Status: {status} (Expected: 200)")
- assert status == 200, f"Expected 200, got {status}"
-
+ path: str = "/token"
+ payload: dict[str, str] = {"email": email}
+ print(f"\nFetching token from {path} with email: {email}")
+ response: Response = await client.post(path, json=payload)
+ data = response.json()
+ print(f"Token Response Status: {response.status_code}")
+ print(f"Token Response: {json.dumps(data, indent=2)}")
+ assert response.status_code == codes.OK
+ return data["access_token"]
+
+
+@pytest_asyncio.fixture(loop_scope="class")
+async def setup_session() -> AsyncGenerator[tuple[AsyncClient, str], None]:
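+    """Yield an authenticated (client, token) pair shared by the TestAPI class."""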
+ async with async_client() as client:
+ token: str = await get_token(client)
+ assert token, "Failed to get token"
+
+ yield client, token
+
+
+@pytest.mark.asyncio(loop_scope="class")
+class TestAPI:
+ async def endpoint(
+ self,
+ client: AsyncClient,
+ token: str,
+ endpoint: str,
+ url: str,
+ params: Optional[dict] = None,
+ expected_status: int = codes.OK,
+ ) -> Union[dict, str]:
+ """Test an endpoint with token and print results."""
+ path = f"/{endpoint}/{quote(url)}"
+ if params:
+ path += "?" + "&".join(f"{k}={v}" for k, v in params.items())
+
+ headers = {"Authorization": f"Bearer {token}"}
+ print(f"\nTesting: {path}")
+
+ response: Response = await client.get(path, headers=headers)
+ content_type: str = response.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower()
+ data: Union[dict, str] = (
+ response.json() if content_type == "application/json" else response.text
+ )
+ print(f"Status: {response.status_code} (Expected: {expected_status})")
+ if isinstance(data, dict):
+ print(f"Response: {json.dumps(data, indent=2)}")
+ else:
+ print(f"Response: {data[:500]}...") # First 500 chars
+ assert response.status_code == expected_status, (
+ f"Expected {expected_status}, got {response.status_code}"
+ )
+ return data
+
+ async def stream_crawl(self, client: AsyncClient, token: str):
+ """Test the /crawl/stream endpoint with multiple URLs."""
+ url = "/crawl/stream"
+ payload = {
+ "urls": [
+ "https://example.com",
+ "https://example.com/page1", # Replicated example.com with variation
+ "https://example.com/page2", # Replicated example.com with variation
+ "https://example.com/page3", # Replicated example.com with variation
+ ],
+ "browser_config": {"headless": True, "viewport": {"width": 1200}},
+ "crawler_config": {"stream": True, "cache_mode": "aggressive"},
+ }
+ headers = {"Authorization": f"Bearer {token}"}
+ print(f"\nTesting Streaming Crawl: {url}")
+ print(f"Payload: {json.dumps(payload, indent=2)}")
+
+ async with client.stream(
+ "POST", url, json=payload, headers=headers
+ ) as response:
+ print(f"Status: {response.status_code} (Expected: {codes.OK})")
+ assert response.status_code == codes.OK, (
+ f"Expected {codes.OK}, got {response.status_code}"
+ )
+
# Read streaming response line-by-line (NDJSON)
- async for line in response.content:
+ async for line in response.aiter_lines():
if line:
- data = json.loads(line.decode('utf-8').strip())
+ data = json.loads(line.strip())
print(f"Streamed Result: {json.dumps(data, indent=2)}")
- except Exception as e:
- print(f"Error in streaming crawl test: {str(e)}")
-
-async def run_tests():
- import aiohttp
- print("Starting API Tests...")
-
- # Test URLs
- urls = [
- "example.com",
- "https://www.python.org",
- "https://news.ycombinator.com/news",
- "https://github.com/trending"
- ]
-
- async with aiohttp.ClientSession() as session:
- # Fetch token once and reuse it
- token = await get_token(session)
- if not token:
- print("Aborting tests due to token failure!")
- return
-
- print("\n=== Testing Crawl Endpoint ===")
+
+ async def test_crawl_endpoint(self, setup_session: tuple[AsyncClient, str]):
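+        """Submit a basic non-streaming crawl to /crawl using the bearer token."""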
+ client, token = setup_session
crawl_payload = {
"urls": ["https://example.com"],
"browser_config": {"headless": True},
- "crawler_config": {"stream": False}
+ "crawler_config": {"stream": False},
}
- async with session.post(
- "http://localhost:8000/crawl",
+ response = await client.post(
+ "/crawl",
json=crawl_payload,
- headers={"Authorization": f"Bearer {token}"}
- ) as response:
- status = response.status
- data = await response.json()
- print(f"\nCrawl Endpoint Status: {status}")
- print(f"Crawl Response: {json.dumps(data, indent=2)}")
-
-
- print("\n=== Testing Crawl Stream Endpoint ===")
- await test_stream_crawl(session, token)
-
- print("\n=== Testing Markdown Endpoint ===")
- for url in []: #urls:
- for filter_type in ["raw", "fit", "bm25", "llm"]:
- params = {"f": filter_type}
- if filter_type in ["bm25", "llm"]:
- params["q"] = "extract main content"
-
- for cache in ["0", "1"]:
- params["c"] = cache
- await test_endpoint(session, "md", url, token, params)
- await asyncio.sleep(1) # Be nice to the server
-
- print("\n=== Testing LLM Endpoint ===")
- for url in urls:
- # Test basic extraction (direct response now)
- result = await test_endpoint(
- session,
- "llm",
- url,
- token,
- {"q": "Extract title and main content"}
- )
-
- # Test with schema (direct response)
- schema = {
- "type": "object",
- "properties": {
- "title": {"type": "string"},
- "content": {"type": "string"},
- "links": {"type": "array", "items": {"type": "string"}}
- }
- }
- result = await test_endpoint(
- session,
- "llm",
- url,
- token,
- {
- "q": "Extract content with links",
- "s": json.dumps(schema),
- "c": "1" # Test with cache
- }
- )
- await asyncio.sleep(2) # Be nice to the server
-
- print("\n=== Testing Error Cases ===")
- # Test invalid URL
- await test_endpoint(
- session,
- "md",
- "not_a_real_url",
+ headers={"Authorization": f"Bearer {token}"},
+ )
+ data = response.json()
+ print(f"\nCrawl Endpoint Status: {response.status_code}")
+ print(f"Crawl Response: {json.dumps(data, indent=2)}")
+
+ @pytest.mark.asyncio
+ async def test_stream_endpoint(self, setup_session: tuple[AsyncClient, str]):
+ client, token = setup_session
+ await self.stream_crawl(client, token)
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("url,params", markdown_params())
+ @pytest.mark.timeout(60) # LLM tasks can take a while.
+ async def test_markdown_endpoint(
+ self,
+ url: str,
+ params: dict[str, str],
+ setup_session: tuple[AsyncClient, str],
+ ):
+ client, token = setup_session
+ result: Union[dict, str] = await self.endpoint(client, token, "md", url, params)
+ assert isinstance(result, str), "Expected str response"
+ assert result, "Expected non-empty response"
+
+ @pytest.mark.parametrize("url", TEST_URLS)
+ async def test_llm_endpoint_no_schema(
+ self, url: str, setup_session: tuple[AsyncClient, str]
+ ):
+ client, token = setup_session
+ # Test basic extraction (direct response now)
+ result = await self.endpoint(
+ client,
token,
- expected_status=500
+ "llm",
+ url,
+ {"q": "Extract title and main content"},
+ )
+ assert isinstance(result, dict), "Expected dict response"
+ assert "answer" in result, "Expected 'answer' key in response"
+
+    # Currently the server handles LLM requests via handle_llm_qa,
+    # which doesn't go through handle_llm_request, where the schema
+    # is processed.
+ @pytest.mark.parametrize("url", TEST_URLS)
+ @pytest.mark.skip("LLM endpoint doesn't schema request yet")
+ async def test_llm_endpoint_schema(
+ self, url: str, setup_session: tuple[AsyncClient, str]
+ ):
+ client, token = setup_session
+ schema = {
+ "type": "object",
+ "properties": {
+ "title": {"type": "string"},
+ "content": {"type": "string"},
+ "links": {"type": "array", "items": {"type": "string"}},
+ },
+ }
+ result = await self.endpoint(
+ client,
+ token,
+ "llm",
+ url,
+ {
+ "q": "Extract content with links",
+ "s": json.dumps(schema),
+ "c": "1", # Test with cache
+ },
)
-
- # Test invalid filter type
- await test_endpoint(
- session,
+ assert isinstance(result, dict), "Expected dict response"
+ assert "answer" in result, "Expected 'answer' key in response"
+ print(result)
+
+ async def test_invalid_url(self, setup_session: tuple[AsyncClient, str]):
+ client, token = setup_session
+ await self.endpoint(client, token, "md", "not_a_real_url", expected_status=codes.INTERNAL_SERVER_ERROR)
+
+ async def test_invalid_filter(self, setup_session: tuple[AsyncClient, str]):
+ client, token = setup_session
+ await self.endpoint(
+ client,
+ token,
"md",
"example.com",
- token,
{"f": "invalid"},
- expected_status=422
+ expected_status=codes.UNPROCESSABLE_ENTITY,
)
-
+
+ async def test_missing_query(self, setup_session: tuple[AsyncClient, str]):
+ client, token = setup_session
# Test LLM without query (should fail per your server logic)
- await test_endpoint(
- session,
- "llm",
- "example.com",
- token,
- expected_status=400
- )
-
- print("\nAll tests completed!")
+ await self.endpoint(client, token, "llm", "example.com", expected_status=codes.BAD_REQUEST)
+
if __name__ == "__main__":
- asyncio.run(run_tests())
\ No newline at end of file
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/docker_example.py b/tests/docker_example.py
index 336ca52f6..49b878d9e 100644
--- a/tests/docker_example.py
+++ b/tests/docker_example.py
@@ -1,14 +1,17 @@
+from httpx import codes
import requests
import json
import time
import sys
import base64
import os
-from typing import Dict, Any
+from typing import Dict, Any, Optional
class Crawl4AiTester:
- def __init__(self, base_url: str = "http://localhost:11235", api_token: str = None):
+ def __init__(
+ self, base_url: str = "http://localhost:11235", api_token: Optional[str] = None
+ ):
self.base_url = base_url
self.api_token = api_token or os.getenv(
"CRAWL4AI_API_TOKEN"
@@ -24,7 +27,7 @@ def submit_and_wait(
response = requests.post(
f"{self.base_url}/crawl", json=request_data, headers=self.headers
)
- if response.status_code == 403:
+ if response.status_code == codes.FORBIDDEN:
raise Exception("API token is invalid or missing")
task_id = response.json()["task_id"]
print(f"Task ID: {task_id}")
@@ -58,7 +61,7 @@ def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
headers=self.headers,
timeout=60,
)
- if response.status_code == 408:
+ if response.status_code == codes.REQUEST_TIMEOUT:
raise TimeoutError("Task did not complete within server timeout")
response.raise_for_status()
return response.json()
@@ -66,7 +69,6 @@ def submit_sync(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
def test_docker_deployment(version="basic"):
tester = Crawl4AiTester(
- # base_url="http://localhost:11235" ,
base_url="https://crawl4ai-sby74.ondigitalocean.app",
api_token="test",
)
diff --git a/tests/hub/test_simple.py b/tests/hub/test_simple.py
index a970d683c..a3d45b174 100644
--- a/tests/hub/test_simple.py
+++ b/tests/hub/test_simple.py
@@ -1,34 +1,50 @@
# test.py
-from crawl4ai import CrawlerHub
import json
+import sys
+from pathlib import Path
+
+import pytest
+
+from crawl4ai import CrawlerHub
+
-async def amazon_example():
- if (crawler_cls := CrawlerHub.get("amazon_product")) :
- crawler = crawler_cls()
- print(f"Crawler version: {crawler_cls.meta['version']}")
- print(f"Rate limits: {crawler_cls.meta.get('rate_limit', 'Unlimited')}")
- print(await crawler.run("https://amazon.com/test"))
- else:
- print("Crawler not found!")
+@pytest.mark.asyncio
+async def test_amazon():
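+    """Load the amazon_product crawler from CrawlerHub and run it."""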
+ crawler_cls = CrawlerHub.get("amazon_product")
+ assert crawler_cls is not None
-async def google_example():
+ crawler = crawler_cls()
+ print(f"Crawler version: {crawler.meta['version']}")
+ print(f"Rate limits: {crawler.meta.get('rate_limit', 'Unlimited')}")
+ result = await crawler.run("https://amazon.com/test")
+ assert result and "product" in result
+ print(result)
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("crawler doesn't pass llm_config to generate_schema, so it fails")
+async def test_google(tmp_path: Path):
# Get crawler dynamically
crawler_cls = CrawlerHub.get("google_search")
+ assert crawler_cls is not None
crawler = crawler_cls()
# Text search
+ schema_cache_path: Path = tmp_path / ".crawl4ai"
text_results = await crawler.run(
- query="apple inc",
- search_type="text",
- schema_cache_path="/Users/unclecode/.crawl4ai"
+ query="apple inc",
+ search_type="text",
+ schema_cache_path=schema_cache_path.as_posix(),
)
print(json.dumps(json.loads(text_results), indent=4))
+ assert text_results
# Image search
# image_results = await crawler.run(query="apple inc", search_type="image")
# print(image_results)
+
if __name__ == "__main__":
- import asyncio
- # asyncio.run(amazon_example())
- asyncio.run(google_example())
\ No newline at end of file
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/legacy/test_cli_docs.py b/tests/legacy/test_cli_docs.py
new file mode 100644
index 000000000..be70d7c80
--- /dev/null
+++ b/tests/legacy/test_cli_docs.py
@@ -0,0 +1,20 @@
+import sys
+
+import pytest
+
+from crawl4ai.legacy.docs_manager import DocsManager
+
+
+@pytest.mark.asyncio
+async def test_cli():
+ """Test all CLI commands"""
+ print("\n1. Testing docs update...")
+ docs_manager = DocsManager()
+ result = await docs_manager.fetch_docs()
+ assert result
+
+
+if __name__ == "__main__":
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/loggers/test_logger.py b/tests/loggers/test_logger.py
index 6c3a811b4..a8a074592 100644
--- a/tests/loggers/test_logger.py
+++ b/tests/loggers/test_logger.py
@@ -1,7 +1,18 @@
-import asyncio
-from crawl4ai import AsyncWebCrawler, BrowserConfig, CrawlerRunConfig, CacheMode, AsyncLoggerBase
import os
+import sys
from datetime import datetime
+from pathlib import Path
+
+import pytest
+
+from crawl4ai import (
+ AsyncLoggerBase,
+ AsyncWebCrawler,
+ BrowserConfig,
+ CacheMode,
+ CrawlerRunConfig,
+)
+
class AsyncFileLogger(AsyncLoggerBase):
"""
@@ -55,26 +66,25 @@ def error_status(self, url: str, error: str, tag: str = "ERROR", url_length: int
message = f"{url[:url_length]}... | Error: {error}"
self._write_to_file("ERROR", message, tag)
-async def main():
+
+@pytest.mark.asyncio
+async def test_logger(tmp_path: Path):
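+    """Crawl with a custom file logger attached and verify repeated runs succeed."""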
+ log_file: Path = tmp_path / "test.log"
browser_config = BrowserConfig(headless=True, verbose=True)
- crawler = AsyncWebCrawler(config=browser_config, logger=AsyncFileLogger("/Users/unclecode/devs/crawl4ai/.private/tmp/crawl.log"))
- await crawler.start()
-
- try:
+    async with AsyncWebCrawler(config=browser_config, logger=AsyncFileLogger(log_file.as_posix())) as crawler:
crawl_config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
)
+
# Use the crawler multiple times
- result = await crawler.arun(
- url='https://kidocode.com/',
- config=crawl_config
- )
- if result.success:
- print("First crawl - Raw Markdown Length:", len(result.markdown.raw_markdown))
-
- finally:
- # Always ensure we close the crawler
- await crawler.close()
+ for _ in range(3):
+ result = await crawler.arun(
+ url='https://kidocode.com/',
+ config=crawl_config
+ )
+ assert result.success
if __name__ == "__main__":
- asyncio.run(main())
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/memory/test_crawler_monitor.py b/tests/memory/test_crawler_monitor.py
index 89cc08b84..b186bf630 100644
--- a/tests/memory/test_crawler_monitor.py
+++ b/tests/memory/test_crawler_monitor.py
@@ -9,6 +9,7 @@
import threading
import sys
import os
+import pytest
# Add the parent directory to the path to import crawl4ai
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
@@ -59,7 +60,7 @@ def simulate_crawler_task(monitor, task_id, url, simulate_failure=False):
memory_usage=0.0
)
-def update_queue_stats(monitor, num_queued_tasks):
+def update_queue_stats(monitor: CrawlerMonitor, num_queued_tasks):
"""Update queue statistics periodically."""
while monitor.is_running:
queued_tasks = [
@@ -102,6 +103,7 @@ def update_queue_stats(monitor, num_queued_tasks):
time.sleep(1.0)
+@pytest.mark.timeout(60)
def test_crawler_monitor():
"""Test the CrawlerMonitor with simulated crawler tasks."""
# Total number of URLs to crawl
@@ -165,4 +167,6 @@ def test_crawler_monitor():
print("\nCrawler monitor test completed")
if __name__ == "__main__":
+ # Doesn't use the standard pytest entry point as it's not compatible
+ # with the tty output in the terminal.
test_crawler_monitor()
\ No newline at end of file
diff --git a/tests/memory/test_dispatcher_stress.py b/tests/memory/test_dispatcher_stress.py
index f81f78f65..684546cc2 100644
--- a/tests/memory/test_dispatcher_stress.py
+++ b/tests/memory/test_dispatcher_stress.py
@@ -3,13 +3,14 @@
import psutil
import logging
import random
-from typing import List, Dict
+from typing import List, Dict, Optional
import uuid
import sys
import os
+import pytest
# Import your crawler components
-from crawl4ai.models import DisplayMode, CrawlStatus, CrawlResult
+from crawl4ai.models import CrawlStatus
from crawl4ai.async_configs import CrawlerRunConfig, BrowserConfig, CacheMode
from crawl4ai import AsyncWebCrawler
from crawl4ai import MemoryAdaptiveDispatcher, CrawlerMonitor
@@ -73,8 +74,8 @@ def apply_pressure(self, additional_percent: float = 0.0):
time.sleep(0.5) # Give system time to register the allocation
except MemoryError:
logger.warning("Unable to allocate more memory")
-
- def release_pressure(self, percent: float = None):
+
+ def release_pressure(self, percent: Optional[float] = None):
"""
Release allocated memory
If percent is specified, release that percentage of blocks
@@ -101,24 +102,24 @@ def spike_pressure(self, duration: float = 5.0):
logger.info(f"Creating memory pressure spike for {duration} seconds")
# Save current blocks count
initial_blocks = len(self.memory_blocks)
-
+
# Create spike with extra 5%
self.apply_pressure(additional_percent=5.0)
-
+
# Schedule release after duration
asyncio.create_task(self._delayed_release(duration, initial_blocks))
-
+
async def _delayed_release(self, delay: float, target_blocks: int):
"""Helper for spike_pressure - releases extra blocks after delay"""
await asyncio.sleep(delay)
-
+
# Remove blocks added since spike started
if len(self.memory_blocks) > target_blocks:
logger.info(f"Releasing memory spike ({len(self.memory_blocks) - target_blocks} blocks)")
self.memory_blocks = self.memory_blocks[:target_blocks]
-
+
# Test statistics collector
-class TestResults:
+class StressTestResults:
def __init__(self):
self.start_time = time.time()
self.completed_urls: List[str] = []
@@ -167,7 +168,7 @@ def log_summary(self):
# Custom monitor with stats tracking
# Custom monitor that extends CrawlerMonitor with test-specific tracking
class StressTestMonitor(CrawlerMonitor):
- def __init__(self, test_results: TestResults, **kwargs):
+ def __init__(self, test_results: StressTestResults, **kwargs):
# Initialize the parent CrawlerMonitor
super().__init__(**kwargs)
self.test_results = test_results
@@ -192,32 +193,51 @@ def update_queue_statistics(self, total_queued: int, highest_wait_time: float, a
# Call parent method to update the dashboard
super().update_queue_statistics(total_queued, highest_wait_time, avg_wait_time)
-
- def update_task(self, task_id: str, **kwargs):
+
+ def update_task(
+ self,
+ task_id: str,
+ status: Optional[CrawlStatus] = None,
+ start_time: Optional[float] = None,
+ end_time: Optional[float] = None,
+ memory_usage: Optional[float] = None,
+ peak_memory: Optional[float] = None,
+ error_message: Optional[str] = None,
+ retry_count: Optional[int] = None,
+ wait_time: Optional[float] = None
+ ):
# Track URL status changes for test results
if task_id in self.stats:
- old_status = self.stats[task_id].status
-
+ old_status = self.stats[task_id]["status"]
+
# If this is a requeue event (requeued due to memory pressure)
- if 'error_message' in kwargs and 'requeued' in kwargs['error_message']:
- if not hasattr(self.stats[task_id], 'counted_requeue') or not self.stats[task_id].counted_requeue:
+ if error_message and 'requeued' in error_message:
+ if not self.stats[task_id]["counted_requeue"]:
self.test_results.requeued_count += 1
- self.stats[task_id].counted_requeue = True
-
+ self.stats[task_id]["counted_requeue"] = True
+
# Track completion status for test results
- if 'status' in kwargs:
- new_status = kwargs['status']
- if old_status != new_status:
- if new_status == CrawlStatus.COMPLETED:
+ if status:
+ if old_status != status:
+ if status == CrawlStatus.COMPLETED:
if task_id not in self.test_results.completed_urls:
self.test_results.completed_urls.append(task_id)
- elif new_status == CrawlStatus.FAILED:
+ elif status == CrawlStatus.FAILED:
if task_id not in self.test_results.failed_urls:
self.test_results.failed_urls.append(task_id)
# Call parent method to update the dashboard
- super().update_task(task_id, **kwargs)
- self.live.update(self._create_table())
+ super().update_task(
+ task_id=task_id,
+ status=status,
+ start_time=start_time,
+ end_time=end_time,
+ memory_usage=memory_usage,
+ peak_memory=peak_memory,
+ error_message=error_message,
+ retry_count=retry_count,
+ wait_time=wait_time
+ )
# Generate test URLs - use example.com with unique paths to avoid browser caching
def generate_test_urls(count: int) -> List[str]:
@@ -230,7 +250,7 @@ def generate_test_urls(count: int) -> List[str]:
return urls
# Process result callback
-async def process_result(result, test_results: TestResults):
+async def process_result(result, test_results: StressTestResults):
# Track attempt counts
if result.url not in test_results.url_to_attempt:
test_results.url_to_attempt[result.url] = 1
@@ -248,7 +268,7 @@ async def process_result(result, test_results: TestResults):
logger.warning(f"Failed to process: {result.url} - {result.error_message}")
# Process multiple results (used in non-streaming mode)
-async def process_results(results, test_results: TestResults):
+async def process_results(results, test_results: StressTestResults):
for result in results:
await process_result(result, test_results)
@@ -260,7 +280,25 @@ async def run_memory_stress_test(
aggressive: bool = False,
spikes: bool = True
):
- test_results = TestResults()
+ # Scale values based on initial memory usage.
+ # With no initial memory pressure, we can use the default values:
+ # - 55% is the threshold for normal operation - plenty of memory
+ # - 63% is the threshold for throttling
+ # - 70% is the threshold for requeuing - incredibly aggressive
+ initial_percent: float = psutil.virtual_memory().percent
+ baseline: float = 55.0
+ if initial_percent > baseline:
+ baseline = initial_percent + 5.0
+
+ memory_threshold_percent = baseline + 8
+ critical_threshold_percent = baseline + 15
+ recovery_threshold_percent = baseline
+
+ if critical_threshold_percent > 99.0:
+ pytest.skip("Memory pressure too high for this system")
+
+ test_results = StressTestResults()
+
memory_simulator = MemorySimulator(target_percent=target_memory_percent, aggressive=aggressive)
logger.info(f"Starting stress test with {url_count} URLs in {'STREAM' if STREAM else 'NON-STREAM'} mode")
@@ -285,17 +323,15 @@ async def run_memory_stress_test(
# Create monitor with reference to test results
monitor = StressTestMonitor(
test_results=test_results,
- display_mode=DisplayMode.DETAILED,
- max_visible_rows=20,
- total_urls=url_count # Pass total URLs count
+ urls_total=url_count # Pass total URLs count
)
# Create dispatcher with EXTREME settings - pure survival mode
# These settings are designed to create a memory battleground
dispatcher = MemoryAdaptiveDispatcher(
- memory_threshold_percent=63.0, # Start throttling at just 60% memory
- critical_threshold_percent=70.0, # Start requeuing at 70% - incredibly aggressive
- recovery_threshold_percent=55.0, # Only resume normal ops when plenty of memory available
+ memory_threshold_percent=memory_threshold_percent,
+ critical_threshold_percent=critical_threshold_percent,
+ recovery_threshold_percent=recovery_threshold_percent,
check_interval=0.1, # Check extremely frequently (100ms)
max_session_permit=20 if aggressive else 10, # Double the concurrent sessions - pure chaos
fairness_timeout=10.0, # Extremely low timeout - rapid priority changes
@@ -303,8 +339,8 @@ async def run_memory_stress_test(
)
# Set up spike schedule if enabled
+ spike_intervals = []
if spikes:
- spike_intervals = []
# Create 3-5 random spike times
num_spikes = random.randint(3, 5)
for _ in range(num_spikes):
@@ -379,6 +415,12 @@ async def run_memory_stress_test(
logger.info("TEST PASSED: All URLs were processed without crashing.")
return True
+@pytest.mark.asyncio
+@pytest.mark.timeout(600)
+async def test_memory_stress():
+ """Run the memory stress test with default parameters."""
+ assert await run_memory_stress_test()
+
# Command-line entry point
if __name__ == "__main__":
# Parse command line arguments
diff --git a/tests/test_cli_docs.py b/tests/test_cli_docs.py
deleted file mode 100644
index 6941f20db..000000000
--- a/tests/test_cli_docs.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import asyncio
-from crawl4ai.docs_manager import DocsManager
-from click.testing import CliRunner
-from crawl4ai.cli import cli
-
-
-def test_cli():
- """Test all CLI commands"""
- runner = CliRunner()
-
- print("\n1. Testing docs update...")
- # Use sync version for testing
- docs_manager = DocsManager()
- loop = asyncio.get_event_loop()
- loop.run_until_complete(docs_manager.fetch_docs())
-
- # print("\n2. Testing listing...")
- # result = runner.invoke(cli, ['docs', 'list'])
- # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
- # print(result.output)
-
- # print("\n2. Testing index building...")
- # result = runner.invoke(cli, ['docs', 'index'])
- # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
- # print(f"Output: {result.output}")
-
- # print("\n3. Testing search...")
- # result = runner.invoke(cli, ['docs', 'search', 'how to use crawler', '--build-index'])
- # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
- # print(f"First 200 chars: {result.output[:200]}...")
-
- # print("\n4. Testing combine with sections...")
- # result = runner.invoke(cli, ['docs', 'combine', 'chunking_strategies', 'extraction_strategies', '--mode', 'extended'])
- # print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
- # print(f"First 200 chars: {result.output[:200]}...")
-
- print("\n5. Testing combine all sections...")
- result = runner.invoke(cli, ["docs", "combine", "--mode", "condensed"])
- print(f"Status: {'✅' if result.exit_code == 0 else '❌'}")
- print(f"First 200 chars: {result.output[:200]}...")
-
-
-if __name__ == "__main__":
- test_cli()
diff --git a/tests/test_crawl_result_container.py b/tests/test_crawl_result_container.py
new file mode 100644
index 000000000..84cd3b7ae
--- /dev/null
+++ b/tests/test_crawl_result_container.py
@@ -0,0 +1,133 @@
+from typing import Any, AsyncGenerator
+
+import pytest
+from _pytest.mark.structures import ParameterSet
+
+from crawl4ai.models import CrawlResultContainer, CrawlResult
+
+RESULT: CrawlResult = CrawlResult(
+ url="https://example.com", success=True, html="Test content"
+)
+
+
+def result_container_params() -> list[ParameterSet]:
+ """Return a list of test parameters for the CrawlResultContainer tests.
+
+    :return: Test parameters, each a tuple of (result: CrawlResultContainer, expected: list[CrawlResult]).
+ :rtype: list[ParameterSet]
+ """
+
+ async def async_generator(results: list[CrawlResult]) -> AsyncGenerator[CrawlResult, Any]:
+ for result in results:
+ yield result
+
+ return [
+ pytest.param(CrawlResultContainer(RESULT), [RESULT], id="result"),
+ pytest.param(CrawlResultContainer([]), [], id="list_empty"),
+ pytest.param(CrawlResultContainer([RESULT]), [RESULT], id="list_single"),
+ pytest.param(CrawlResultContainer([RESULT, RESULT]), [RESULT, RESULT], id="list_multi"),
+ pytest.param(CrawlResultContainer(async_generator([])), [], id="async_empty"),
+ pytest.param(CrawlResultContainer(async_generator([RESULT])), [RESULT], id="async_single"),
+ pytest.param(CrawlResultContainer(async_generator([RESULT, RESULT])), [RESULT, RESULT], id="async_multi"),
+ ]
+
+
+@pytest.mark.parametrize("result,expected", result_container_params())
+def test_iter(result: CrawlResultContainer, expected: list[CrawlResult]):
+ """Test __iter__ of the CrawlResultContainer."""
+ if isinstance(result.source, AsyncGenerator):
+ with pytest.raises(TypeError):
+ for item in result:
+ pass
+ return
+
+ results: list[CrawlResult] = []
+ for item in result:
+ results.append(item)
+
+ assert results == expected
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("result,expected", result_container_params())
+async def test_aiter(result: CrawlResultContainer, expected: list[CrawlResult]):
+ """Test __aiter__ of the CrawlResultContainer."""
+ results: list[CrawlResult] = []
+ async for item in result:
+ results.append(item)
+
+ assert results == expected
+
+
+@pytest.mark.parametrize("result,expected", result_container_params())
+def test_getitem(result: CrawlResultContainer, expected: list[CrawlResult]):
+ """Test the __getitem__ method of the CrawlResultContainer."""
+ if isinstance(result.source, AsyncGenerator):
+ with pytest.raises(TypeError):
+ assert result[0] == expected[0]
+
+ return
+
+ for i in range(len(expected)):
+ assert result[i] == expected[i]
+
+
+@pytest.mark.parametrize("result,expected", result_container_params())
+def test_len(result: CrawlResultContainer, expected: list[CrawlResult]):
+ """Test the __len__ of the CrawlResultContainer."""
+ if isinstance(result.source, AsyncGenerator):
+ with pytest.raises(TypeError):
+ assert len(result) == len(expected)
+
+ return
+
+ assert len(result) == len(expected)
+
+
+def result_attributes() -> list[str]:
+ """Return a list of attributes to test for CrawlResult
+
+ :return: List of valid attributes, excluding private, callable, and deprecated attributes.
+ :rtype: list[str]
+ """
+
+    # hasattr guards against class-only attributes and deprecated attributes that raise on access;
+    # the property check looks at the class so we test the property object, not its computed value.
+ return [
+ attr
+ for attr in dir(RESULT)
+        if attr in RESULT.model_fields or (hasattr(RESULT, attr) and isinstance(getattr(type(RESULT), attr, None), property))
+ ]
+
+
+@pytest.mark.parametrize("result,expected", result_container_params())
+def test_getattribute(result: CrawlResultContainer, expected: list[CrawlResult]):
+ """Test the __getattribute__ method of the CrawlResultContainer."""
+ assert result.source is not None
+
+ if isinstance(result.source, AsyncGenerator):
+ for attr in result_attributes():
+ with pytest.raises(TypeError):
+ assert getattr(result, attr) == getattr(RESULT, attr)
+
+ return
+
+ if not expected:
+ for attr in result_attributes():
+ with pytest.raises(AttributeError):
+ assert getattr(result, attr) == getattr(RESULT, attr)
+
+ return
+
+ for attr in result_attributes():
+ assert getattr(result, attr) == getattr(RESULT, attr)
+
+
+@pytest.mark.parametrize("result,expected", result_container_params())
+def test_repr(result: CrawlResultContainer, expected: list[CrawlResult]):
+ """Test the __repr__ method of the CrawlResultContainer."""
+ if isinstance(result.source, AsyncGenerator):
+ assert repr(result) == "CrawlResultContainer([])"
+
+ return
+
+ assert repr(result) == f"CrawlResultContainer({repr(expected)})"
diff --git a/tests/test_docker.py b/tests/test_docker.py
deleted file mode 100644
index 3570d608d..000000000
--- a/tests/test_docker.py
+++ /dev/null
@@ -1,299 +0,0 @@
-import requests
-import json
-import time
-import sys
-import base64
-import os
-from typing import Dict, Any
-
-
-class Crawl4AiTester:
- def __init__(self, base_url: str = "http://localhost:11235"):
- self.base_url = base_url
-
- def submit_and_wait(
- self, request_data: Dict[str, Any], timeout: int = 300
- ) -> Dict[str, Any]:
- # Submit crawl job
- response = requests.post(f"{self.base_url}/crawl", json=request_data)
- task_id = response.json()["task_id"]
- print(f"Task ID: {task_id}")
-
- # Poll for result
- start_time = time.time()
- while True:
- if time.time() - start_time > timeout:
- raise TimeoutError(
- f"Task {task_id} did not complete within {timeout} seconds"
- )
-
- result = requests.get(f"{self.base_url}/task/{task_id}")
- status = result.json()
-
- if status["status"] == "failed":
- print("Task failed:", status.get("error"))
- raise Exception(f"Task failed: {status.get('error')}")
-
- if status["status"] == "completed":
- return status
-
- time.sleep(2)
-
-
-def test_docker_deployment(version="basic"):
- tester = Crawl4AiTester()
- print(f"Testing Crawl4AI Docker {version} version")
-
- # Health check with timeout and retry
- max_retries = 5
- for i in range(max_retries):
- try:
- health = requests.get(f"{tester.base_url}/health", timeout=10)
- print("Health check:", health.json())
- break
- except requests.exceptions.RequestException:
- if i == max_retries - 1:
- print(f"Failed to connect after {max_retries} attempts")
- sys.exit(1)
- print(f"Waiting for service to start (attempt {i+1}/{max_retries})...")
- time.sleep(5)
-
- # Test cases based on version
- test_basic_crawl(tester)
-
- # if version in ["full", "transformer"]:
- # test_cosine_extraction(tester)
-
- # test_js_execution(tester)
- # test_css_selector(tester)
- # test_structured_extraction(tester)
- # test_llm_extraction(tester)
- # test_llm_with_ollama(tester)
- # test_screenshot(tester)
-
-
-def test_basic_crawl(tester: Crawl4AiTester):
- print("\n=== Testing Basic Crawl ===")
- request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
-
- result = tester.submit_and_wait(request)
- print(f"Basic crawl result length: {len(result['result']['markdown'])}")
- assert result["result"]["success"]
- assert len(result["result"]["markdown"]) > 0
-
-
-def test_js_execution(tester: Crawl4AiTester):
- print("\n=== Testing JS Execution ===")
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 8,
- "js_code": [
- "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
- ],
- "wait_for": "article.tease-card:nth-child(10)",
- "crawler_params": {"headless": True},
- }
-
- result = tester.submit_and_wait(request)
- print(f"JS execution result length: {len(result['result']['markdown'])}")
- assert result["result"]["success"]
-
-
-def test_css_selector(tester: Crawl4AiTester):
- print("\n=== Testing CSS Selector ===")
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 7,
- "css_selector": ".wide-tease-item__description",
- "crawler_params": {"headless": True},
- "extra": {"word_count_threshold": 10},
- }
-
- result = tester.submit_and_wait(request)
- print(f"CSS selector result length: {len(result['result']['markdown'])}")
- assert result["result"]["success"]
-
-
-def test_structured_extraction(tester: Crawl4AiTester):
- print("\n=== Testing Structured Extraction ===")
- schema = {
- "name": "Coinbase Crypto Prices",
- "baseSelector": ".cds-tableRow-t45thuk",
- "fields": [
- {
- "name": "crypto",
- "selector": "td:nth-child(1) h2",
- "type": "text",
- },
- {
- "name": "symbol",
- "selector": "td:nth-child(1) p",
- "type": "text",
- },
- {
- "name": "price",
- "selector": "td:nth-child(2)",
- "type": "text",
- },
- ],
- }
-
- request = {
- "urls": "https://www.coinbase.com/explore",
- "priority": 9,
- "extraction_config": {"type": "json_css", "params": {"schema": schema}},
- }
-
- result = tester.submit_and_wait(request)
- extracted = json.loads(result["result"]["extracted_content"])
- print(f"Extracted {len(extracted)} items")
- print("Sample item:", json.dumps(extracted[0], indent=2))
- assert result["result"]["success"]
- assert len(extracted) > 0
-
-
-def test_llm_extraction(tester: Crawl4AiTester):
- print("\n=== Testing LLM Extraction ===")
- schema = {
- "type": "object",
- "properties": {
- "model_name": {
- "type": "string",
- "description": "Name of the OpenAI model.",
- },
- "input_fee": {
- "type": "string",
- "description": "Fee for input token for the OpenAI model.",
- },
- "output_fee": {
- "type": "string",
- "description": "Fee for output token for the OpenAI model.",
- },
- },
- "required": ["model_name", "input_fee", "output_fee"],
- }
-
- request = {
- "urls": "https://openai.com/api/pricing",
- "priority": 8,
- "extraction_config": {
- "type": "llm",
- "params": {
- "provider": "openai/gpt-4o-mini",
- "api_token": os.getenv("OPENAI_API_KEY"),
- "schema": schema,
- "extraction_type": "schema",
- "instruction": """From the crawled content, extract all mentioned model names along with their fees for input and output tokens.""",
- },
- },
- "crawler_params": {"word_count_threshold": 1},
- }
-
- try:
- result = tester.submit_and_wait(request)
- extracted = json.loads(result["result"]["extracted_content"])
- print(f"Extracted {len(extracted)} model pricing entries")
- print("Sample entry:", json.dumps(extracted[0], indent=2))
- assert result["result"]["success"]
- except Exception as e:
- print(f"LLM extraction test failed (might be due to missing API key): {str(e)}")
-
-
-def test_llm_with_ollama(tester: Crawl4AiTester):
- print("\n=== Testing LLM with Ollama ===")
- schema = {
- "type": "object",
- "properties": {
- "article_title": {
- "type": "string",
- "description": "The main title of the news article",
- },
- "summary": {
- "type": "string",
- "description": "A brief summary of the article content",
- },
- "main_topics": {
- "type": "array",
- "items": {"type": "string"},
- "description": "Main topics or themes discussed in the article",
- },
- },
- }
-
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 8,
- "extraction_config": {
- "type": "llm",
- "params": {
- "provider": "ollama/llama2",
- "schema": schema,
- "extraction_type": "schema",
- "instruction": "Extract the main article information including title, summary, and main topics.",
- },
- },
- "extra": {"word_count_threshold": 1},
- "crawler_params": {"verbose": True},
- }
-
- try:
- result = tester.submit_and_wait(request)
- extracted = json.loads(result["result"]["extracted_content"])
- print("Extracted content:", json.dumps(extracted, indent=2))
- assert result["result"]["success"]
- except Exception as e:
- print(f"Ollama extraction test failed: {str(e)}")
-
-
-def test_cosine_extraction(tester: Crawl4AiTester):
- print("\n=== Testing Cosine Extraction ===")
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 8,
- "extraction_config": {
- "type": "cosine",
- "params": {
- "semantic_filter": "business finance economy",
- "word_count_threshold": 10,
- "max_dist": 0.2,
- "top_k": 3,
- },
- },
- }
-
- try:
- result = tester.submit_and_wait(request)
- extracted = json.loads(result["result"]["extracted_content"])
- print(f"Extracted {len(extracted)} text clusters")
- print("First cluster tags:", extracted[0]["tags"])
- assert result["result"]["success"]
- except Exception as e:
- print(f"Cosine extraction test failed: {str(e)}")
-
-
-def test_screenshot(tester: Crawl4AiTester):
- print("\n=== Testing Screenshot ===")
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 5,
- "screenshot": True,
- "crawler_params": {"headless": True},
- }
-
- result = tester.submit_and_wait(request)
- print("Screenshot captured:", bool(result["result"]["screenshot"]))
-
- if result["result"]["screenshot"]:
- # Save screenshot
- screenshot_data = base64.b64decode(result["result"]["screenshot"])
- with open("test_screenshot.jpg", "wb") as f:
- f.write(screenshot_data)
- print("Screenshot saved as test_screenshot.jpg")
-
- assert result["result"]["success"]
-
-
-if __name__ == "__main__":
- version = sys.argv[1] if len(sys.argv) > 1 else "basic"
- # version = "full"
- test_docker_deployment(version)
diff --git a/tests/test_llmtxt.py b/tests/test_llmtxt.py
index 2cdb02715..6910a2f43 100644
--- a/tests/test_llmtxt.py
+++ b/tests/test_llmtxt.py
@@ -1,12 +1,15 @@
-from crawl4ai.llmtxt import AsyncLLMTextManager # Changed to AsyncLLMTextManager
-from crawl4ai.async_logger import AsyncLogger
+import sys
from pathlib import Path
-import asyncio
+import pytest
+
+from crawl4ai.async_logger import AsyncLogger
+from crawl4ai.legacy.llmtxt import AsyncLLMTextManager
-async def main():
+
+@pytest.mark.asyncio
+async def test_llm_txt():
current_file = Path(__file__).resolve()
- # base_dir = current_file.parent.parent / "local/_docs/llm.txt/test_docs"
base_dir = current_file.parent.parent / "local/_docs/llm.txt"
docs_dir = base_dir
@@ -16,7 +19,6 @@ async def main():
# Initialize logger
logger = AsyncLogger()
# Updated initialization with default batching params
- # manager = AsyncLLMTextManager(docs_dir, logger, max_concurrent_calls=3, batch_size=2)
manager = AsyncLLMTextManager(docs_dir, logger, batch_size=2)
# Let's first check what files we have
@@ -39,14 +41,13 @@ async def main():
for query in test_queries:
print(f"\nQuery: {query}")
results = manager.search(query, top_k=2)
- print(f"Results length: {len(results)} characters")
- if results:
- print(
- "First 200 chars of results:", results[:200].replace("\n", " "), "..."
- )
- else:
- print("No results found")
+ assert results, "No results found"
+ print(
+ "First 200 chars of results:", results[:200].replace("\n", " "), "..."
+ )
if __name__ == "__main__":
- asyncio.run(main())
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/test_main.py b/tests/test_main.py
deleted file mode 100644
index 0e938f590..000000000
--- a/tests/test_main.py
+++ /dev/null
@@ -1,276 +0,0 @@
-import asyncio
-import aiohttp
-import json
-import time
-import os
-from typing import Dict, Any
-
-
-class NBCNewsAPITest:
- def __init__(self, base_url: str = "http://localhost:8000"):
- self.base_url = base_url
- self.session = None
-
- async def __aenter__(self):
- self.session = aiohttp.ClientSession()
- return self
-
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- if self.session:
- await self.session.close()
-
- async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
- async with self.session.post(
- f"{self.base_url}/crawl", json=request_data
- ) as response:
- result = await response.json()
- return result["task_id"]
-
- async def get_task_status(self, task_id: str) -> Dict[str, Any]:
- async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
- return await response.json()
-
- async def wait_for_task(
- self, task_id: str, timeout: int = 300, poll_interval: int = 2
- ) -> Dict[str, Any]:
- start_time = time.time()
- while True:
- if time.time() - start_time > timeout:
- raise TimeoutError(
- f"Task {task_id} did not complete within {timeout} seconds"
- )
-
- status = await self.get_task_status(task_id)
- if status["status"] in ["completed", "failed"]:
- return status
-
- await asyncio.sleep(poll_interval)
-
- async def check_health(self) -> Dict[str, Any]:
- async with self.session.get(f"{self.base_url}/health") as response:
- return await response.json()
-
-
-async def test_basic_crawl():
- print("\n=== Testing Basic Crawl ===")
- async with NBCNewsAPITest() as api:
- request = {"urls": "https://www.nbcnews.com/business", "priority": 10}
- task_id = await api.submit_crawl(request)
- result = await api.wait_for_task(task_id)
- print(f"Basic crawl result length: {len(result['result']['markdown'])}")
- assert result["status"] == "completed"
- assert "result" in result
- assert result["result"]["success"]
-
-
-async def test_js_execution():
- print("\n=== Testing JS Execution ===")
- async with NBCNewsAPITest() as api:
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 8,
- "js_code": [
- "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
- ],
- "wait_for": "article.tease-card:nth-child(10)",
- "crawler_params": {"headless": True},
- }
- task_id = await api.submit_crawl(request)
- result = await api.wait_for_task(task_id)
- print(f"JS execution result length: {len(result['result']['markdown'])}")
- assert result["status"] == "completed"
- assert result["result"]["success"]
-
-
-async def test_css_selector():
- print("\n=== Testing CSS Selector ===")
- async with NBCNewsAPITest() as api:
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 7,
- "css_selector": ".wide-tease-item__description",
- }
- task_id = await api.submit_crawl(request)
- result = await api.wait_for_task(task_id)
- print(f"CSS selector result length: {len(result['result']['markdown'])}")
- assert result["status"] == "completed"
- assert result["result"]["success"]
-
-
-async def test_structured_extraction():
- print("\n=== Testing Structured Extraction ===")
- async with NBCNewsAPITest() as api:
- schema = {
- "name": "NBC News Articles",
- "baseSelector": "article.tease-card",
- "fields": [
- {"name": "title", "selector": "h2", "type": "text"},
- {
- "name": "description",
- "selector": ".tease-card__description",
- "type": "text",
- },
- {
- "name": "link",
- "selector": "a",
- "type": "attribute",
- "attribute": "href",
- },
- ],
- }
-
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 9,
- "extraction_config": {"type": "json_css", "params": {"schema": schema}},
- }
- task_id = await api.submit_crawl(request)
- result = await api.wait_for_task(task_id)
- extracted = json.loads(result["result"]["extracted_content"])
- print(f"Extracted {len(extracted)} articles")
- assert result["status"] == "completed"
- assert result["result"]["success"]
- assert len(extracted) > 0
-
-
-async def test_batch_crawl():
- print("\n=== Testing Batch Crawl ===")
- async with NBCNewsAPITest() as api:
- request = {
- "urls": [
- "https://www.nbcnews.com/business",
- "https://www.nbcnews.com/business/consumer",
- "https://www.nbcnews.com/business/economy",
- ],
- "priority": 6,
- "crawler_params": {"headless": True},
- }
- task_id = await api.submit_crawl(request)
- result = await api.wait_for_task(task_id)
- print(f"Batch crawl completed, got {len(result['results'])} results")
- assert result["status"] == "completed"
- assert "results" in result
- assert len(result["results"]) == 3
-
-
-async def test_llm_extraction():
- print("\n=== Testing LLM Extraction with Ollama ===")
- async with NBCNewsAPITest() as api:
- schema = {
- "type": "object",
- "properties": {
- "article_title": {
- "type": "string",
- "description": "The main title of the news article",
- },
- "summary": {
- "type": "string",
- "description": "A brief summary of the article content",
- },
- "main_topics": {
- "type": "array",
- "items": {"type": "string"},
- "description": "Main topics or themes discussed in the article",
- },
- },
- "required": ["article_title", "summary", "main_topics"],
- }
-
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 8,
- "extraction_config": {
- "type": "llm",
- "params": {
- "provider": "openai/gpt-4o-mini",
- "api_key": os.getenv("OLLAMA_API_KEY"),
- "schema": schema,
- "extraction_type": "schema",
- "instruction": """Extract the main article information including title, a brief summary, and main topics discussed.
- Focus on the primary business news article on the page.""",
- },
- },
- "crawler_params": {"headless": True, "word_count_threshold": 1},
- }
-
- task_id = await api.submit_crawl(request)
- result = await api.wait_for_task(task_id)
-
- if result["status"] == "completed":
- extracted = json.loads(result["result"]["extracted_content"])
- print("Extracted article analysis:")
- print(json.dumps(extracted, indent=2))
-
- assert result["status"] == "completed"
- assert result["result"]["success"]
-
-
-async def test_screenshot():
- print("\n=== Testing Screenshot ===")
- async with NBCNewsAPITest() as api:
- request = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 5,
- "screenshot": True,
- "crawler_params": {"headless": True},
- }
- task_id = await api.submit_crawl(request)
- result = await api.wait_for_task(task_id)
- print("Screenshot captured:", bool(result["result"]["screenshot"]))
- assert result["status"] == "completed"
- assert result["result"]["success"]
- assert result["result"]["screenshot"] is not None
-
-
-async def test_priority_handling():
- print("\n=== Testing Priority Handling ===")
- async with NBCNewsAPITest() as api:
- # Submit low priority task first
- low_priority = {
- "urls": "https://www.nbcnews.com/business",
- "priority": 1,
- "crawler_params": {"headless": True},
- }
- low_task_id = await api.submit_crawl(low_priority)
-
- # Submit high priority task
- high_priority = {
- "urls": "https://www.nbcnews.com/business/consumer",
- "priority": 10,
- "crawler_params": {"headless": True},
- }
- high_task_id = await api.submit_crawl(high_priority)
-
- # Get both results
- high_result = await api.wait_for_task(high_task_id)
- low_result = await api.wait_for_task(low_task_id)
-
- print("Both tasks completed")
- assert high_result["status"] == "completed"
- assert low_result["status"] == "completed"
-
-
-async def main():
- try:
- # Start with health check
- async with NBCNewsAPITest() as api:
- health = await api.check_health()
- print("Server health:", health)
-
- # Run all tests
- # await test_basic_crawl()
- # await test_js_execution()
- # await test_css_selector()
- # await test_structured_extraction()
- await test_llm_extraction()
- # await test_batch_crawl()
- # await test_screenshot()
- # await test_priority_handling()
-
- except Exception as e:
- print(f"Test failed: {str(e)}")
- raise
-
-
-if __name__ == "__main__":
- asyncio.run(main())
diff --git a/tests/test_scraping_strategy.py b/tests/test_scraping_strategy.py
index df4628540..54be5de0d 100644
--- a/tests/test_scraping_strategy.py
+++ b/tests/test_scraping_strategy.py
@@ -1,26 +1,29 @@
-import nest_asyncio
+import sys
-nest_asyncio.apply()
+import pytest
-import asyncio
from crawl4ai import (
AsyncWebCrawler,
+ CacheMode,
CrawlerRunConfig,
LXMLWebScrapingStrategy,
- CacheMode,
)
-async def main():
+@pytest.mark.asyncio
+async def test_scraping_strategy():
config = CrawlerRunConfig(
cache_mode=CacheMode.BYPASS,
scraping_strategy=LXMLWebScrapingStrategy(), # Faster alternative to default BeautifulSoup
)
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(url="https://example.com", config=config)
- print(f"Success: {result.success}")
- print(f"Markdown length: {len(result.markdown.raw_markdown)}")
+ assert result.success
+ assert result.markdown
+ assert result.markdown.raw_markdown
if __name__ == "__main__":
- asyncio.run(main())
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
diff --git a/tests/test_web_crawler.py b/tests/test_web_crawler.py
index b84531924..f684f0acb 100644
--- a/tests/test_web_crawler.py
+++ b/tests/test_web_crawler.py
@@ -1,17 +1,23 @@
-import unittest, os
-from crawl4ai import LLMConfig
-from crawl4ai.web_crawler import WebCrawler
+import os
+import sys
+import unittest
+
+import pytest
+
+from crawl4ai import CacheMode
+from crawl4ai.async_configs import LLMConfig
from crawl4ai.chunking_strategy import (
- RegexChunking,
FixedLengthWordChunking,
+ RegexChunking,
SlidingWindowChunking,
+ TopicSegmentationChunking,
)
from crawl4ai.extraction_strategy import (
CosineStrategy,
LLMExtractionStrategy,
- TopicExtractionStrategy,
NoExtractionStrategy,
)
+from crawl4ai.legacy.web_crawler import WebCrawler
class TestWebCrawler(unittest.TestCase):
@@ -28,13 +34,16 @@ def test_run_default_strategies(self):
word_count_threshold=5,
chunking_strategy=RegexChunking(),
extraction_strategy=CosineStrategy(),
- bypass_cache=True,
+ cache_mode=CacheMode.BYPASS,
+ warmup=False,
)
self.assertTrue(
result.success, "Failed to crawl and extract using default strategies"
)
def test_run_different_strategies(self):
+ if not os.getenv("OPENAI_API_KEY"):
+            self.skipTest("Skipping: OPENAI_API_KEY is not set")
url = "https://www.nbcnews.com/business"
# Test with FixedLengthWordChunking and LLMExtractionStrategy
@@ -45,7 +54,8 @@ def test_run_different_strategies(self):
extraction_strategy=LLMExtractionStrategy(
llm_config=LLMConfig(provider="openai/gpt-3.5-turbo", api_token=os.getenv("OPENAI_API_KEY"))
),
- bypass_cache=True,
+ cache_mode=CacheMode.BYPASS,
+ warmup=False
)
self.assertTrue(
result.success,
@@ -56,9 +66,11 @@ def test_run_different_strategies(self):
result = self.crawler.run(
url=url,
word_count_threshold=5,
- chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
- extraction_strategy=TopicExtractionStrategy(num_keywords=5),
- bypass_cache=True,
+ chunking_strategy=TopicSegmentationChunking(
+ window_size=100, step=50, num_keywords=5
+ ),
+ cache_mode=CacheMode.BYPASS,
+ warmup=False,
)
self.assertTrue(
result.success,
@@ -66,37 +78,46 @@ def test_run_different_strategies(self):
)
def test_invalid_url(self):
- with self.assertRaises(Exception) as context:
- self.crawler.run(url="invalid_url", bypass_cache=True)
- self.assertIn("Invalid URL", str(context.exception))
+ result = self.crawler.run(
+ url="invalid_url", cache_mode=CacheMode.BYPASS, warmup=False
+ )
+ self.assertFalse(result.success, "Extraction should fail with invalid URL")
+ msg = "" if not result.error_message else result.error_message
+ self.assertTrue("invalid argument" in msg)
def test_unsupported_extraction_strategy(self):
- with self.assertRaises(Exception) as context:
- self.crawler.run(
- url="https://www.nbcnews.com/business",
- extraction_strategy="UnsupportedStrategy",
- bypass_cache=True,
- )
- self.assertIn("Unsupported extraction strategy", str(context.exception))
+ result = self.crawler.run(
+ url="https://www.nbcnews.com/business",
+ extraction_strategy="UnsupportedStrategy", # pyright: ignore[reportArgumentType]
+ cache_mode=CacheMode.BYPASS,
+ )
+ self.assertFalse(
+ result.success, "Extraction should fail with unsupported strategy"
+ )
+ self.assertEqual("Unsupported extraction strategy", result.error_message)
+ @pytest.mark.skip("Skipping InvalidCSSSelectorError is no longer raised")
def test_invalid_css_selector(self):
- with self.assertRaises(ValueError) as context:
- self.crawler.run(
- url="https://www.nbcnews.com/business",
- css_selector="invalid_selector",
- bypass_cache=True,
- )
- self.assertIn("Invalid CSS selector", str(context.exception))
+ result = self.crawler.run(
+ url="https://www.nbcnews.com/business",
+ css_selector="invalid_selector",
+ cache_mode=CacheMode.BYPASS,
+ warmup=False
+ )
+ self.assertFalse(
+ result.success, "Extraction should fail with invalid CSS selector"
+ )
+ self.assertEqual("Invalid CSS selector", result.error_message)
def test_crawl_with_cache_and_bypass_cache(self):
url = "https://www.nbcnews.com/business"
# First crawl with cache enabled
- result = self.crawler.run(url=url, bypass_cache=False)
+ result = self.crawler.run(url=url, bypass_cache=False, warmup=False)
self.assertTrue(result.success, "Failed to crawl and cache the result")
- # Second crawl with bypass_cache=True
- result = self.crawler.run(url=url, bypass_cache=True)
+ # Second crawl with cache_mode=CacheMode.BYPASS
+ result = self.crawler.run(url=url, cache_mode=CacheMode.BYPASS, warmup=False)
self.assertTrue(result.success, "Failed to bypass cache and fetch fresh data")
def test_fetch_multiple_pages(self):
@@ -108,7 +129,8 @@ def test_fetch_multiple_pages(self):
word_count_threshold=5,
chunking_strategy=RegexChunking(),
extraction_strategy=CosineStrategy(),
- bypass_cache=True,
+ cache_mode=CacheMode.BYPASS,
+ warmup=False
)
results.append(result)
@@ -124,7 +146,8 @@ def test_run_fixed_length_word_chunking_and_no_extraction(self):
word_count_threshold=5,
chunking_strategy=FixedLengthWordChunking(chunk_size=100),
extraction_strategy=NoExtractionStrategy(),
- bypass_cache=True,
+ cache_mode=CacheMode.BYPASS,
+ warmup=False,
)
self.assertTrue(
result.success,
@@ -137,7 +160,8 @@ def test_run_sliding_window_and_no_extraction(self):
word_count_threshold=5,
chunking_strategy=SlidingWindowChunking(window_size=100, step=50),
extraction_strategy=NoExtractionStrategy(),
- bypass_cache=True,
+ cache_mode=CacheMode.BYPASS,
+ warmup=False
)
self.assertTrue(
result.success,
@@ -146,4 +170,6 @@ def test_run_sliding_window_and_no_extraction(self):
if __name__ == "__main__":
- unittest.main()
+ import subprocess
+
+ sys.exit(subprocess.call(["pytest", *sys.argv[1:], sys.argv[0]]))
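
The recurring substitution in this last file is the move from the removed bypass_cache flag to cache_mode plus an explicit warmup=False. A hedged sketch of the updated call pattern, with keyword names taken from the hunks above and the URL purely illustrative:

    from crawl4ai import CacheMode
    from crawl4ai.legacy.web_crawler import WebCrawler

    crawler = WebCrawler()  # the tests build this in setUp(); constructor arguments are not shown in the diff
    result = crawler.run(
        url="https://example.com",
        cache_mode=CacheMode.BYPASS,  # replaces the removed bypass_cache=True
        warmup=False,                 # the updated tests disable warmup on each run
    )
    assert result.success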