diff --git a/tests/test_docs.py b/tests/test_docs.py index bc93527b..d1543afd 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -10,6 +10,8 @@ import inline_snapshot._inline_snapshot import pytest +from inline_snapshot import snapshot +from inline_snapshot.extra import raises @dataclass @@ -20,113 +22,212 @@ class Block: line: int -def map_code_blocks(file): - def w(func): - - block_start = re.compile("( *)``` *python(.*)") - block_end = re.compile("```.*") - - header = re.compile("") - - current_code = file.read_text("utf-8") - new_lines = [] - block_lines = [] - options = set() - is_block = False - code = None - indent = "" - block_start_linenum = None - block_options = None - code_header = None - header_line = "" - - for linenumber, line in enumerate(current_code.splitlines(), start=1): - m = block_start.fullmatch(line) - if m and not is_block: - # ``` python - block_start_linenum = linenumber - indent = m[1] - block_options = m[2] - block_lines = [] - is_block = True - continue - - if block_end.fullmatch(line.strip()) and is_block: - # ``` - is_block = False - - code = "\n".join(block_lines) + "\n" - code = textwrap.dedent(code) - if file.suffix == ".py": - code = code.replace("\\\\", "\\") - - try: - new_block = func( - Block( - code=code, - code_header=code_header, - block_options=block_options, - line=block_start_linenum, - ) +def map_code_blocks(file, func, fix=False): + + block_start = re.compile("( *)``` *python(.*)") + block_end = re.compile("```.*") + + header = re.compile("") + + current_code = file.read_text("utf-8") + new_lines = [] + block_lines = [] + options = set() + is_block = False + code = None + indent = "" + block_start_linenum = None + block_options = None + code_header = None + header_line = "" + + for linenumber, line in enumerate(current_code.splitlines(), start=1): + m = block_start.fullmatch(line) + if m and not is_block: + # ``` python + block_start_linenum = linenumber + indent = m[1] + block_options = m[2] + block_lines = [] + is_block = True + continue + + if block_end.fullmatch(line.strip()) and is_block: + # ``` + is_block = False + + code = "\n".join(block_lines) + "\n" + code = textwrap.dedent(code) + if file.suffix == ".py": + code = code.replace("\\\\", "\\") + + try: + new_block = func( + Block( + code=code, + code_header=code_header, + block_options=block_options, + line=block_start_linenum, ) - except Exception: - print(f"error at block at line {block_start_linenum}") - print(f"{code_header=}") - print(f"{block_options=}") - print(code) - raise - - if new_block.code_header is not None: - new_lines.append( - f"{indent}" - ) - - new_lines.append( - f"{indent}``` {('python '+new_block.block_options.strip()).strip()}" ) + except Exception: + print(f"error at block at line {block_start_linenum}") + print(f"{code_header=}") + print(f"{block_options=}") + print(code) + raise + + if new_block.code_header is not None: + new_lines.append(f"{indent}") + + new_lines.append( + f"{indent}``` {('python '+new_block.block_options.strip()).strip()}" + ) - new_code = new_block.code.rstrip("\n") - if file.suffix == ".py": - new_code = new_code.replace("\\", "\\\\") - new_code = textwrap.indent(new_code, indent) + new_code = new_block.code.rstrip() + if file.suffix == ".py": + new_code = new_code.replace("\\", "\\\\") + new_code = textwrap.indent(new_code, indent) - new_lines.append(new_code) + new_lines.append(new_code) - new_lines.append(f"{indent}```") + new_lines.append(f"{indent}```") - header_line = "" + header_line = "" + code_header = None + + continue + + if is_block: + block_lines.append(line) + continue + + m = header.fullmatch(line.strip()) + if m: + # comment + header_line = line + code_header = m[1].strip() + continue + else: + if header_line: + new_lines.append(header_line) code_header = None + header_line = "" - continue + new_lines.append(line) - if is_block: - block_lines.append(line) - continue + new_code = "\n".join(new_lines) + "\n" - m = header.fullmatch(line.strip()) - if m: - # comment - header_line = line - code_header = m[1].strip() - continue - else: - if header_line: - new_lines.append(header_line) - code_header = None - header_line = "" + if fix: + file.write_text(new_code) + else: + assert current_code.splitlines() == new_code.splitlines() + assert current_code == new_code + + +def test_map_code_blocks(tmp_path): + + file = tmp_path / "example.md" + + def test_doc( + markdown_code, + handle_block=lambda block: exec(block.code), + blocks=[], + exception="", + new_markdown_code=None, + ): - if not is_block: - new_lines.append(line) + file.write_text(markdown_code) - new_code = "\n".join(new_lines) + "\n" + recorded_blocks = [] - if inline_snapshot._inline_snapshot._update_flags.fix: - file.write_text(new_code) + with raises(exception): + + def test_block(block): + handle_block(block) + recorded_blocks.append(block) + return block + + map_code_blocks(file, test_block, True) + assert recorded_blocks == blocks + map_code_blocks(file, test_block, False) + + recorded_markdown_code = file.read_text() + if recorded_markdown_code != markdown_code: + assert new_markdown_code == recorded_markdown_code else: - assert current_code.splitlines() == new_code.splitlines() - assert current_code == new_code + assert new_markdown_code == None - return w + test_doc( + """ +``` python +1 / 0 +``` +""", + exception=snapshot("ZeroDivisionError: division by zero"), + ) + + test_doc( + """\ +text +``` python +print(1 + 1) +``` +text + +``` python hl_lines="1 2 3" +print(1 - 1) +``` +text +""", + blocks=snapshot( + [ + Block( + code="print(1 + 1)\n", code_header=None, block_options="", line=2 + ), + Block( + code="print(1 - 1)\n", + code_header="inline-snapshot: create test", + block_options=' hl_lines="1 2 3"', + line=7, + ), + ] + ), + ) + + def change_block(block): + block.code = "# removed" + block.code_header = "header" + block.block_options = "option a b c" + + test_doc( + """\ +text +``` python +print(1 + 1) +``` +""", + handle_block=change_block, + blocks=snapshot( + [ + Block( + code="# removed", + code_header="header", + block_options="option a b c", + line=2, + ) + ] + ), + new_markdown_code=snapshot( + """\ +text + +``` python option a b c +# removed +``` +""" + ), + ) @pytest.mark.skipif( @@ -168,8 +269,7 @@ def test_docs(project, file, subtests): extra_files = defaultdict(list) - @map_code_blocks(file) - def _(block: Block): + def test_block(block: Block): if block.code_header is None: return block @@ -264,3 +364,7 @@ def _(block: Block): last_code = code return block + + map_code_blocks( + file, test_block, inline_snapshot._inline_snapshot._update_flags.fix + )