diff --git a/docs/customize_repr.md b/docs/customize_repr.md index 177548f6..4bde16c0 100644 --- a/docs/customize_repr.md +++ b/docs/customize_repr.md @@ -42,7 +42,7 @@ def test_enum(): inline-snapshot comes with a special implementation for the following types: -```python exec="1" +``` python exec="1" from inline_snapshot._code_repr import code_repr_dispatch, code_repr for name, obj in sorted( @@ -60,7 +60,7 @@ for name, obj in sorted( Container types like `dict` or `dataclass` need a special implementation because it is necessary that the implementation uses `repr()` for the child elements. -```python exec="1" result="python" +``` python exec="1" result="python" print('--8<-- "src/inline_snapshot/_code_repr.py:list"') ``` diff --git a/docs/eq_snapshot.md b/docs/eq_snapshot.md index 94b02f5f..49e62935 100644 --- a/docs/eq_snapshot.md +++ b/docs/eq_snapshot.md @@ -33,9 +33,31 @@ Example: def test_something(): assert 2 + 40 == snapshot(42) ``` +## unmanaged snapshot parts +inline-snapshots manages everything inside `snapshot(...)`, which means that the developer should not change these parts, but there are cases where it is useful to give the developer a bit more control over the snapshot content. -## dirty-equals +Therefor some types will be ignored by inline-snapshot and will **not be updated or fixed**, even if they cause tests to fail. + +These types are: + +* dirty-equals expression +* dynamic code inside `Is(...)` +* and snapshots inside snapshots. + +inline-snapshot is able to handle these types inside the following containers: + +* list +* tuple +* dict +* namedtuple +* dataclass + + +### dirty-equals It might be, that larger snapshots with many lists and dictionaries contain some values which change frequently and are not relevant for the test. They might be part of larger data structures and be difficult to normalize. @@ -82,7 +104,7 @@ Example: inline-snapshot tries to change only the values that it needs to change in order to pass the equality comparison. This allows to replace parts of the snapshot with [dirty-equals](https://dirty-equals.helpmanual.io/latest/) expressions. -This expressions are preserved as long as the `==` comparison with them is `True`. +This expressions are preserved even if the `==` comparison with them is `False`. Example: @@ -159,8 +181,149 @@ Example: ) ``` -!!! note - The current implementation looks only into lists, dictionaries and tuples and not into the representation of other data structures. +### Is(...) + +`Is()` can be used to put runtime values inside snapshots. +It tells inline-snapshot that the developer wants control over some part of the snapshot. + + +``` python +from inline_snapshot import snapshot, Is + +current_version = "1.5" + + +def request(): + return {"data": "page data", "version": current_version} + + +def test_function(): + assert request() == snapshot( + {"data": "page data", "version": Is(current_version)} + ) +``` + +The `current_version` can now be changed without having to correct the snapshot. + +`Is()` can also be used when the snapshot is evaluated multiple times. + +=== "original code" + + ``` python + from inline_snapshot import snapshot, Is + + + def test_function(): + for c in "abc": + assert [c, "correct"] == snapshot([Is(c), "wrong"]) + ``` + +=== "--inline-snapshot=fix" + + ``` python hl_lines="6" + from inline_snapshot import snapshot, Is + + + def test_function(): + for c in "abc": + assert [c, "correct"] == snapshot([Is(c), "correct"]) + ``` + +### inner snapshots + +Snapshots can be used inside other snapshots in different use cases. + +#### conditional snapshots +It is possible to describe version specific parts of snapshots by replacing the specific part with `#!python snapshot() if some_condition else snapshot()`. +The test has to be executed in each specific condition to fill the snapshots. + +The following example shows how this can be used to run a tests with two different library versions: + +=== "my_lib v1" + + + ``` python + version = 1 + + + def get_schema(): + return [{"name": "var_1", "type": "int"}] + ``` + +=== "my_lib v2" + + + ``` python + version = 2 + + + def get_schema(): + return [{"name": "var_1", "type": "string"}] + ``` + + + +``` python +from inline_snapshot import snapshot +from my_lib import version, get_schema + + +def test_function(): + assert get_schema() == snapshot( + [ + { + "name": "var_1", + "type": snapshot("int") if version < 2 else snapshot("string"), + } + ] + ) +``` + +The advantage of this approach is that the test uses always the correct values for each library version. + +#### common snapshot parts + +Another usecase is the extraction of common snapshot parts into an extra snapshot: + + +``` python +from inline_snapshot import snapshot + + +def some_data(name): + return {"header": "really long header\n" * 5, "your name": name} + + +def test_function(): + + header = snapshot( + """\ +really long header +really long header +really long header +really long header +really long header +""" + ) + + assert some_data("Tom") == snapshot( + { + "header": header, + "your name": "Tom", + } + ) + + assert some_data("Bob") == snapshot( + { + "header": header, + "your name": "Bob", + } + ) +``` + +This simplifies test data and allows inline-snapshot to update your values if required. +It makes also sure that the header is the same in both cases. + ## pytest options diff --git a/docs/pytest.md b/docs/pytest.md index 5d825618..a7d34aca 100644 --- a/docs/pytest.md +++ b/docs/pytest.md @@ -11,7 +11,7 @@ inline-snapshot provides one pytest option with different flags (*create*, Snapshot comparisons return always `True` if you use one of the flags *create*, *fix* or *review*. This is necessary because the whole test needs to be run to fix all snapshots like in this case: -```python +``` python from inline_snapshot import snapshot @@ -30,7 +30,7 @@ def test_something(): Approve the changes of the given [category](categories.md). These flags can be combined with *report* and *review*. -```python title="test_something.py" +``` python title="test_something.py" from inline_snapshot import snapshot diff --git a/pyproject.toml b/pyproject.toml index 2465a1bf..c15a7e02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,13 +75,6 @@ omit = [ parallel = true source_pkgs = ["inline_snapshot", "tests"] -[tool.hatch.envs.coverage] -dependencies = [ - "coverage" -] -env-vars.TOP = "{root}" -scripts.report = "coverage html" - [tool.hatch.envs.docs] dependencies = [ "markdown-exec[ansi]>=1.8.0", @@ -119,6 +112,12 @@ extra-dependencies = [ ] env-vars.TOP = "{root}" +[tool.hatch.envs.hatch-test.scripts] +run = "pytest{env:HATCH_TEST_ARGS:} {args}" +run-cov = "coverage run -m pytest{env:HATCH_TEST_ARGS:} {args}" +cov-combine = "coverage combine" +cov-report=["coverage report","coverage html"] + [tool.hatch.envs.types] extra-dependencies = [ "mypy>=1.0.0", @@ -164,6 +163,7 @@ venvPath = ".nox" format = "md" version = "command: cz bump --get-next" -[tool.inline-snapshot.shortcuts] -sfix="create,fix" -review="create,review" +[tool.pytest.ini_options] +markers = [ + "no_rewriting: The test does not use the ast-nodes for rewriting", +] diff --git a/src/inline_snapshot/__init__.py b/src/inline_snapshot/__init__.py index 08cab885..74440cac 100644 --- a/src/inline_snapshot/__init__.py +++ b/src/inline_snapshot/__init__.py @@ -3,9 +3,19 @@ from ._external import external from ._external import outsource from ._inline_snapshot import snapshot +from ._is import Is from ._types import Category from ._types import Snapshot -__all__ = ["snapshot", "external", "outsource", "customize_repr", "HasRepr"] +__all__ = [ + "snapshot", + "external", + "outsource", + "customize_repr", + "HasRepr", + "Is", + "Category", + "Snapshot", +] __version__ = "0.14.0" diff --git a/src/inline_snapshot/_adapter/__init__.py b/src/inline_snapshot/_adapter/__init__.py new file mode 100644 index 00000000..2f699011 --- /dev/null +++ b/src/inline_snapshot/_adapter/__init__.py @@ -0,0 +1,3 @@ +from .adapter import get_adapter_type + +__all__ = ("get_adapter_type",) diff --git a/src/inline_snapshot/_adapter/adapter.py b/src/inline_snapshot/_adapter/adapter.py new file mode 100644 index 00000000..bd4d26b9 --- /dev/null +++ b/src/inline_snapshot/_adapter/adapter.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import ast +import typing + +from inline_snapshot._source_file import SourceFile + + +def get_adapter_type(value): + from inline_snapshot._adapter.dataclass_adapter import get_adapter_for_type + + adapter = get_adapter_for_type(type(value)) + if adapter is not None: + return adapter + + if isinstance(value, list): + from .sequence_adapter import ListAdapter + + return ListAdapter + + if type(value) is tuple: + from .sequence_adapter import TupleAdapter + + return TupleAdapter + + if isinstance(value, dict): + from .dict_adapter import DictAdapter + + return DictAdapter + + from .value_adapter import ValueAdapter + + return ValueAdapter + + +class Item(typing.NamedTuple): + value: typing.Any + node: ast.expr + + +class Adapter: + context: SourceFile + + def __init__(self, context): + self.context = context + + def get_adapter(self, old_value, new_value) -> Adapter: + if type(old_value) is not type(new_value): + from .value_adapter import ValueAdapter + + return ValueAdapter(self.context) + + adapter_type = get_adapter_type(old_value) + if adapter_type is not None: + return adapter_type(self.context) + assert False + + def assign(self, old_value, old_node, new_value): + raise NotImplementedError(cls) + + @classmethod + def map(cls, value, map_function): + raise NotImplementedError(cls) + + @classmethod + def repr(cls, value): + raise NotImplementedError(cls) + + +def adapter_map(value, map_function): + return get_adapter_type(value).map(value, map_function) diff --git a/src/inline_snapshot/_adapter/dataclass_adapter.py b/src/inline_snapshot/_adapter/dataclass_adapter.py new file mode 100644 index 00000000..767dcd57 --- /dev/null +++ b/src/inline_snapshot/_adapter/dataclass_adapter.py @@ -0,0 +1,327 @@ +from __future__ import annotations + +import ast +import warnings +from abc import ABC +from collections import defaultdict +from dataclasses import fields +from dataclasses import is_dataclass +from dataclasses import MISSING +from typing import Any + +from inline_snapshot._adapter.value_adapter import ValueAdapter + +from .._change import CallArg +from .._change import Delete +from ..syntax_warnings import InlineSnapshotSyntaxWarning +from .adapter import Adapter +from .adapter import adapter_map +from .adapter import Item + + +def get_adapter_for_type(typ): + subclasses = DataclassAdapter.__subclasses__() + options = [cls for cls in subclasses if cls.check_type(typ)] + # print(typ,options) + if not options: + return + + assert len(options) == 1 + return options[0] + + +class DataclassAdapter(Adapter): + + @classmethod + def check_type(cls, typ) -> bool: + raise NotImplementedError(cls) + + @classmethod + def arguments(cls, value) -> tuple[list[Any], dict[str, Any]]: + raise NotImplementedError(cls) + + @classmethod + def argument(cls, value, pos_or_name) -> Any: + raise NotImplementedError(cls) + + @classmethod + def repr(cls, value): + + args, kwargs = cls.arguments(value) + + arguments = [repr(value) for value in args] + [ + f"{key}={repr(value)}" for key, value in kwargs.items() + ] + + return f"{repr(type(value))}({', '.join(arguments)})" + + @classmethod + def map(cls, value, map_function): + new_args, new_kwargs = cls.arguments(value) + return type(value)( + *[adapter_map(arg, map_function) for arg in new_args], + **{k: adapter_map(kwarg, map_function) for k, kwarg in new_kwargs.items()}, + ) + + def items(self, value, node): + assert isinstance(node, ast.Call) + assert not node.args + assert all(kw.arg for kw in node.keywords) + + return [ + Item(value=self.argument(value, kw.arg), node=kw.value) + for kw in node.keywords + if kw.arg + ] + + def assign(self, old_value, old_node, new_value): + if old_node is None: + value = yield from ValueAdapter(self.context).assign( + old_value, old_node, new_value + ) + return value + + assert isinstance(old_node, ast.Call) + + # positional arguments + for pos_arg in old_node.args: + if isinstance(pos_arg, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context._source.filename, + lineno=pos_arg.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + # keyword arguments + for kw in old_node.keywords: + if kw.arg is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context._source.filename, + lineno=kw.value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + new_args, new_kwargs = self.arguments(new_value) + + # positional arguments + + result_args = [] + + for i, (new_value_element, node) in enumerate(zip(new_args, old_node.args)): + old_value_element = self.argument(old_value, i) + result = yield from self.get_adapter( + old_value_element, new_value_element + ).assign(old_value_element, node, new_value_element) + result_args.append(result) + + print(old_node.args) + print(new_args) + if len(old_node.args) > len(new_args): + for arg_pos, node in list(enumerate(old_node.args))[len(new_args) :]: + print("del", arg_pos) + yield Delete( + "fix", + self.context._source, + node, + self.argument(old_value, arg_pos), + ) + + if len(old_node.args) < len(new_args): + for insert_pos, value in list(enumerate(new_args))[len(old_node.args) :]: + yield CallArg( + flag="fix", + file=self.context._source, + node=old_node, + arg_pos=insert_pos, + arg_name=None, + new_code=self.context._value_to_code(value), + new_value=value, + ) + + # keyword arguments + result_kwargs = {} + for kw in old_node.keywords: + if not kw.arg in new_kwargs: + # delete entries + yield Delete( + "fix", + self.context._source, + kw.value, + self.argument(old_value, kw.arg), + ) + + old_node_kwargs = {kw.arg: kw.value for kw in old_node.keywords} + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_kwargs.items(): + if key not in old_node_kwargs: + # add new values + to_insert.append((key, new_value_element)) + result_kwargs[key] = new_value_element + else: + node = old_node_kwargs[key] + + # check values with same keys + old_value_element = self.argument(old_value, key) + result_kwargs[key] = yield from self.get_adapter( + old_value_element, new_value_element + ).assign(old_value_element, node, new_value_element) + + if to_insert: + for key, value in to_insert: + + yield CallArg( + flag="fix", + file=self.context._source, + node=old_node, + arg_pos=insert_pos, + arg_name=key, + new_code=self.context._value_to_code(value), + new_value=value, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + + for key, value in to_insert: + + yield CallArg( + flag="fix", + file=self.context._source, + node=old_node, + arg_pos=insert_pos, + arg_name=key, + new_code=self.context._value_to_code(value), + new_value=value, + ) + + return type(old_value)(*result_args, **result_kwargs) + + +class DataclassContainer(DataclassAdapter): + + @classmethod + def check_type(cls, value): + return is_dataclass(value) + + @classmethod + def arguments(cls, value): + + kwargs = {} + + for field in fields(value): # type: ignore + if field.repr: + field_value = getattr(value, field.name) + + if field.default != MISSING and field.default == field_value: + continue + + if ( + field.default_factory != MISSING + and field.default_factory() == field_value + ): + continue + + kwargs[field.name] = field_value + + return ([], kwargs) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + +try: + from pydantic import BaseModel +except ImportError: # pragma: no cover + pass +else: + + class PydanticContainer(DataclassAdapter): + + @classmethod + def check_type(cls, value): + return issubclass(value, BaseModel) + + @classmethod + def arguments(cls, value): + + return ( + [], + { + name: getattr(value, name) + for name, info in value.model_fields.items() + if getattr(value, name) != info.default + }, + ) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + +class IsNamedTuple(ABC): + _inline_snapshot_name = "namedtuple" + + _fields: tuple + _field_defaults: dict + + @classmethod + def __subclasshook__(cls, t): + b = t.__bases__ + if len(b) != 1 or b[0] != tuple: + return False + f = getattr(t, "_fields", None) + if not isinstance(f, tuple): + return False + return all(type(n) == str for n in f) + + +class NamedTupleContainer(DataclassAdapter): + + @classmethod + def check_type(cls, value): + return issubclass(value, IsNamedTuple) + + @classmethod + def arguments(cls, value: IsNamedTuple): + + return ( + [], + { + field: getattr(value, field) + for field in value._fields + if field not in value._field_defaults + or getattr(value, field) != value._field_defaults[field] + }, + ) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, str) + return getattr(value, pos_or_name) + + +class DefaultDictContainer(DataclassAdapter): + @classmethod + def check_type(cls, value): + return issubclass(value, defaultdict) + + @classmethod + def arguments(cls, value: defaultdict): + + return ([value.default_factory, dict(value)], {}) + + def argument(self, value, pos_or_name): + assert isinstance(pos_or_name, int) + if pos_or_name == 0: + return value.default_factory + elif pos_or_name == 1: + return dict(value) + assert False diff --git a/src/inline_snapshot/_adapter/dict_adapter.py b/src/inline_snapshot/_adapter/dict_adapter.py new file mode 100644 index 00000000..4e0cf940 --- /dev/null +++ b/src/inline_snapshot/_adapter/dict_adapter.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +import ast +import warnings + +from .._change import Delete +from .._change import DictInsert +from ..syntax_warnings import InlineSnapshotSyntaxWarning +from .adapter import Adapter +from .adapter import adapter_map +from .adapter import Item + + +class DictAdapter(Adapter): + + @classmethod + def repr(cls, value): + result = ( + "{" + + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) + + "}" + ) + + if type(value) is not dict: + result = f"{repr(type(value))}({result})" + + return result + + @classmethod + def map(cls, value, map_function): + return {k: adapter_map(v, map_function) for k, v in value.items()} + + def items(self, value, node): + if node is None: + return [Item(value=value, node=None) for value in value.values()] + + assert isinstance(node, ast.Dict) + + result = [] + + for value_key, node_key, node_value in zip( + value.keys(), node.keys, node.values + ): + try: + # this is just a sanity check, dicts should be ordered + node_key = ast.literal_eval(node_key) + except Exception: + pass + else: + assert node_key == value_key + + result.append(Item(value=value[value_key], node=node_value)) + + return result + + def assign(self, old_value, old_node, new_value): + if old_node is not None: + assert isinstance(old_node, ast.Dict) + assert len(old_value) == len(old_node.keys) + + for key, value in zip(old_node.keys, old_node.values): + if key is None: + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context._source.filename, + lineno=value.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + for value, node in zip(old_value.keys(), old_node.keys): + + try: + # this is just a sanity check, dicts should be ordered + node_value = ast.literal_eval(node) + except: + continue + assert node_value == value + + result = {} + for key, node in zip( + old_value.keys(), + (old_node.values if old_node is not None else [None] * len(old_value)), + ): + if not key in new_value: + # delete entries + yield Delete("fix", self.context._source, node, old_value[key]) + + to_insert = [] + insert_pos = 0 + for key, new_value_element in new_value.items(): + if key not in old_value: + # add new values + to_insert.append((key, new_value_element)) + result[key] = new_value_element + else: + if isinstance(old_node, ast.Dict): + node = old_node.values[list(old_value.keys()).index(key)] + else: + node = None + # check values with same keys + result[key] = yield from self.get_adapter( + old_value[key], new_value[key] + ).assign(old_value[key], node, new_value[key]) + + if to_insert: + new_code = [ + (self.context._value_to_code(k), self.context._value_to_code(v)) + for k, v in to_insert + ] + yield DictInsert( + "fix", + self.context._source, + old_node, + insert_pos, + new_code, + to_insert, + ) + to_insert = [] + + insert_pos += 1 + + if to_insert: + new_code = [ + (self.context._value_to_code(k), self.context._value_to_code(v)) + for k, v in to_insert + ] + yield DictInsert( + "fix", + self.context._source, + old_node, + len(old_value), + new_code, + to_insert, + ) + + return result diff --git a/src/inline_snapshot/_adapter/sequence_adapter.py b/src/inline_snapshot/_adapter/sequence_adapter.py new file mode 100644 index 00000000..b48a452c --- /dev/null +++ b/src/inline_snapshot/_adapter/sequence_adapter.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import ast +import warnings +from collections import defaultdict + +from .._align import add_x +from .._align import align +from .._change import Delete +from .._change import ListInsert +from .._compare_context import compare_context +from ..syntax_warnings import InlineSnapshotSyntaxWarning +from .adapter import Adapter +from .adapter import adapter_map +from .adapter import Item + + +class SequenceAdapter(Adapter): + node_type: type + value_type: type + braces: str + trailing_comma: bool + + @classmethod + def repr(cls, value): + if len(value) == 1 and cls.trailing_comma: + seq = repr(value[0]) + "," + else: + seq = ", ".join(map(repr, value)) + return cls.braces[0] + seq + cls.braces[1] + + @classmethod + def map(cls, value, map_function): + result = [adapter_map(v, map_function) for v in value] + return cls.value_type(result) + + def items(self, value, node): + if node is None: + return [Item(value=v, node=None) for v in value] + + assert isinstance(node, self.node_type), (node, self) + assert len(value) == len(node.elts) + + return [Item(value=v, node=n) for v, n in zip(value, node.elts)] + + def assign(self, old_value, old_node, new_value): + if old_node is not None: + assert isinstance( + old_node, ast.List if isinstance(old_value, list) else ast.Tuple + ) + + for e in old_node.elts: + if isinstance(e, ast.Starred): + warnings.warn_explicit( + "star-expressions are not supported inside snapshots", + filename=self.context.filename, + lineno=e.lineno, + category=InlineSnapshotSyntaxWarning, + ) + return old_value + + with compare_context(): + diff = add_x(align(old_value, new_value)) + old = zip( + old_value, + old_node.elts if old_node is not None else [None] * len(old_value), + ) + new = iter(new_value) + old_position = 0 + to_insert = defaultdict(list) + result = [] + for c in diff: + if c in "mx": + old_value_element, old_node_element = next(old) + new_value_element = next(new) + v = yield from self.get_adapter( + old_value_element, new_value_element + ).assign(old_value_element, old_node_element, new_value_element) + result.append(v) + old_position += 1 + elif c == "i": + new_value_element = next(new) + new_code = self.context._value_to_code(new_value_element) + result.append(new_value_element) + to_insert[old_position].append((new_code, new_value_element)) + elif c == "d": + old_value_element, old_node_element = next(old) + yield Delete( + "fix", self.context._source, old_node_element, old_value_element + ) + old_position += 1 + else: + assert False + + for position, code_values in to_insert.items(): + yield ListInsert( + "fix", self.context._source, old_node, position, *zip(*code_values) + ) + + return self.value_type(result) + + +class ListAdapter(SequenceAdapter): + node_type = ast.List + value_type = list + braces = "[]" + trailing_comma = False + + +class TupleAdapter(SequenceAdapter): + node_type = ast.Tuple + value_type = tuple + braces = "()" + trailing_comma = True diff --git a/src/inline_snapshot/_adapter/value_adapter.py b/src/inline_snapshot/_adapter/value_adapter.py new file mode 100644 index 00000000..f44d2358 --- /dev/null +++ b/src/inline_snapshot/_adapter/value_adapter.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from inline_snapshot._code_repr import value_code_repr +from inline_snapshot._unmanaged import Unmanaged +from inline_snapshot._unmanaged import update_allowed +from inline_snapshot._utils import value_to_token + +from .._change import Replace +from .adapter import Adapter + + +class ValueAdapter(Adapter): + + @classmethod + def repr(cls, value): + return value_code_repr(value) + + @classmethod + def map(cls, value, map_function): + return map_function(value) + + def assign(self, old_value, old_node, new_value): + # generic fallback + + # because IsStr() != IsStr() + if isinstance(old_value, Unmanaged): + return old_value + + if old_node is None: + new_token = [] + else: + new_token = value_to_token(new_value) + + if not old_value == new_value: + flag = "fix" + elif ( + old_node is not None + and update_allowed(old_value) + and self.context._token_of_node(old_node) != new_token + ): + flag = "update" + else: + # equal and equal repr + return old_value + + new_code = self.context._token_to_code(new_token) + + yield Replace( + node=old_node, + file=self.context._source, + new_code=new_code, + flag=flag, + old_value=old_value, + new_value=new_value, + ) + + return new_value diff --git a/src/inline_snapshot/_change.py b/src/inline_snapshot/_change.py index 691d4f7a..05c888f7 100644 --- a/src/inline_snapshot/_change.py +++ b/src/inline_snapshot/_change.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from typing import Any from typing import cast +from typing import DefaultDict from typing import Dict from typing import List from typing import Optional @@ -11,7 +12,7 @@ from asttokens.util import Token from executing.executing import EnhancedAST -from executing.executing import Source +from inline_snapshot._source_file import SourceFile from ._rewrite_code import ChangeRecorder from ._rewrite_code import end_of @@ -21,11 +22,11 @@ @dataclass() class Change: flag: str - source: Source + file: SourceFile @property def filename(self): - return self.source.filename + return self.file.filename def apply(self): raise NotImplementedError() @@ -76,7 +77,7 @@ class Replace(Change): def apply(self): change = ChangeRecorder.current.new_change() - range = self.source.asttokens().get_text_positions(self.node, False) + range = self.file.asttokens().get_text_positions(self.node, False) change.replace(range, self.new_code, filename=self.filename) @@ -87,40 +88,21 @@ class CallArg(Change): arg_name: Optional[str] new_code: str - old_value: Any new_value: Any - def apply(self): - change = ChangeRecorder.current.new_change() - tokens = list(self.source.asttokens().get_tokens(self.node)) - - call = self.node - tokens = list(self.source.asttokens().get_tokens(call)) - assert isinstance(call, ast.Call) - assert len(call.args) == 0 - assert len(call.keywords) == 0 - assert tokens[-2].string == "(" - assert tokens[-1].string == ")" - - assert self.arg_pos == 0 - assert self.arg_name == None - - change = ChangeRecorder.current.new_change() - change.set_tags("inline_snapshot") - change.replace( - (end_of(tokens[-2]), start_of(tokens[-1])), - self.new_code, - filename=self.filename, - ) +TokenRange = Tuple[Token, Token] -TokenRange = Tuple[Token, Token] +def brace_tokens(source, node) -> TokenRange: + first_token, *_, end_token = source.asttokens().get_tokens(node) + return first_token, end_token def generic_sequence_update( - source: Source, - parent: Union[ast.List, ast.Tuple, ast.Dict], + source: SourceFile, + parent: Union[ast.List, ast.Tuple, ast.Dict, ast.Call], + brace_tokens: TokenRange, parent_elements: List[Union[TokenRange, None]], to_insert: Dict[int, List[str]], ): @@ -128,7 +110,7 @@ def generic_sequence_update( new_code = [] deleted = False - last_token, *_, end_token = source.asttokens().get_tokens(parent) + last_token, end_token = brace_tokens is_start = True elements = 0 @@ -169,7 +151,7 @@ def generic_sequence_update( code = ", " + code if elements == 1 and isinstance(parent, ast.Tuple): - # trailing comma for tuples (1,)i + # trailing comma for tuples (1,) code += "," rec.replace( @@ -180,21 +162,23 @@ def generic_sequence_update( def apply_all(all_changes: List[Change]): - by_parent: Dict[EnhancedAST, List[Union[Delete, DictInsert, ListInsert]]] = ( - defaultdict(list) - ) - sources: Dict[EnhancedAST, Source] = {} + by_parent: Dict[ + EnhancedAST, List[Union[Delete, DictInsert, ListInsert, CallArg]] + ] = defaultdict(list) + sources: Dict[EnhancedAST, SourceFile] = {} for change in all_changes: if isinstance(change, Delete): node = cast(EnhancedAST, change.node).parent + if isinstance(node, ast.keyword): + node = node.parent by_parent[node].append(change) - sources[node] = change.source + sources[node] = change.file - elif isinstance(change, (DictInsert, ListInsert)): + elif isinstance(change, (DictInsert, ListInsert, CallArg)): node = cast(EnhancedAST, change.node) by_parent[node].append(change) - sources[node] = change.source + sources[node] = change.file else: change.apply() @@ -218,11 +202,57 @@ def list_token_range(entry): generic_sequence_update( source, parent, + brace_tokens(source, parent), [None if e in to_delete else list_token_range(e) for e in parent.elts], to_insert, ) - elif isinstance(parent, (ast.Dict)): + elif isinstance(parent, ast.Call): + to_delete = { + change.node for change in changes if isinstance(change, Delete) + } + atok = source.asttokens() + + def arg_token_range(node): + if isinstance(node.parent, ast.keyword): + node = node.parent + r = list(atok.get_tokens(node)) + return r[0], r[-1] + + braces_left = atok.next_token(list(atok.get_tokens(parent.func))[-1]) + assert braces_left.string == "(" + braces_right = list(atok.get_tokens(parent))[-1] + assert braces_right.string == ")" + + to_insert = DefaultDict(list) + + for change in changes: + if isinstance(change, CallArg): + if change.arg_name is not None: + position = ( + change.arg_pos + if change.arg_pos is not None + else len(parent.args) + len(parent.keywords) + ) + to_insert[position].append( + f"{change.arg_name} = {change.new_code}" + ) + else: + assert change.arg_pos is not None + to_insert[change.arg_pos].append(change.new_code) + + generic_sequence_update( + source, + parent, + (braces_left, braces_right), + [ + None if e in to_delete else arg_token_range(e) + for e in parent.args + [kw.value for kw in parent.keywords] + ], + to_insert, + ) + + elif isinstance(parent, ast.Dict): to_delete = { change.node for change in changes if isinstance(change, Delete) } @@ -241,6 +271,7 @@ def dict_token_range(key, value): generic_sequence_update( source, parent, + brace_tokens(source, parent), [ None if value in to_delete else dict_token_range(key, value) for key, value in zip(parent.keys, parent.values) diff --git a/src/inline_snapshot/_code_repr.py b/src/inline_snapshot/_code_repr.py index 9a5dcd3a..0f878d77 100644 --- a/src/inline_snapshot/_code_repr.py +++ b/src/inline_snapshot/_code_repr.py @@ -1,14 +1,10 @@ import ast -from abc import ABC -from collections import defaultdict -from dataclasses import fields -from dataclasses import is_dataclass -from dataclasses import MISSING from enum import Enum from enum import Flag from functools import singledispatch from unittest import mock + real_repr = repr @@ -62,7 +58,7 @@ def customize_repr(f): """Register a funtion which should be used to get the code representation of a object. - ```python + ``` python @customize_repr def _(obj: MyCustomClass): return f"MyCustomClass(attr={repr(obj.attr)})" @@ -78,8 +74,27 @@ def _(obj: MyCustomClass): def code_repr(obj): - with mock.patch("builtins.repr", code_repr): - result = code_repr_dispatch(obj) + + with mock.patch("builtins.repr", mocked_code_repr): + return mocked_code_repr(obj) + + +def mocked_code_repr(obj): + from inline_snapshot._adapter.adapter import get_adapter_type + + adapter = get_adapter_type(obj) + assert adapter is not None + return adapter.repr(obj) + + +def value_code_repr(obj): + if not type(obj) == type(obj): + # dispatch will not work in cases like this + return ( + f"HasRepr({repr(type(obj))}, '< type(obj) can not be compared with == >')" + ) + + result = code_repr_dispatch(obj) try: ast.parse(result) @@ -104,59 +119,6 @@ def _(value: Flag): return " | ".join(f"{name}.{flag.name}" for flag in type(value) if flag in value) -# -8<- [start:list] -@customize_repr -def _(value: list): - return "[" + ", ".join(map(repr, value)) + "]" - - -# -8<- [end:list] - - -class OnlyTuple(ABC): - _inline_snapshot_name = "builtins.tuple" - - @classmethod - def __subclasshook__(cls, t): - return t is tuple - - -@customize_repr -def _(value: OnlyTuple): - assert isinstance(value, tuple) - if len(value) == 1: - return f"({repr(value[0])},)" - return "(" + ", ".join(map(repr, value)) + ")" - - -class IsNamedTuple(ABC): - _inline_snapshot_name = "namedtuple" - - _fields: tuple - _field_defaults: dict - - @classmethod - def __subclasshook__(cls, t): - b = t.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(t, "_fields", None) - if not isinstance(f, tuple): - return False - return all(type(n) == str for n in f) - - -@customize_repr -def _(value: IsNamedTuple): - params = ", ".join( - f"{field}={repr(getattr(value,field))}" - for field in value._fields - if field not in value._field_defaults - or getattr(value, field) != value._field_defaults[field] - ) - return f"{repr(type(value))}({params})" - - @customize_repr def _(value: set): if len(value) == 0: @@ -173,71 +135,6 @@ def _(value: frozenset): return "frozenset({" + ", ".join(map(repr, value)) + "})" -@customize_repr -def _(value: dict): - result = ( - "{" + ", ".join(f"{repr(k)}: {repr(value)}" for k, value in value.items()) + "}" - ) - - if type(value) is not dict: - result = f"{repr(type(value))}({result})" - - return result - - -@customize_repr -def _(value: defaultdict): - return f"defaultdict({repr(value.default_factory)}, {repr(dict(value))})" - - @customize_repr def _(value: type): return value.__qualname__ - - -class IsDataclass(ABC): - _inline_snapshot_name = "dataclass" - - @classmethod - def __subclasshook__(cls, subclass): - return is_dataclass(subclass) - - -@customize_repr -def _(value: IsDataclass): - attrs = [] - for field in fields(value): # type: ignore - if field.repr: - field_value = getattr(value, field.name) - - if field.default != MISSING and field.default == field_value: - continue - - if ( - field.default_factory != MISSING - and field.default_factory() == field_value - ): - continue - - attrs.append(f"{field.name}={repr(field_value)}") - - return f"{repr(type(value))}({', '.join(attrs)})" - - -try: - from pydantic import BaseModel -except ImportError: # pragma: no cover - pass -else: - - @customize_repr - def _(model: BaseModel): - return ( - type(model).__qualname__ - + "(" - + ", ".join( - e + "=" + repr(getattr(model, e)) - for e in sorted(model.__pydantic_fields_set__) - ) - + ")" - ) diff --git a/src/inline_snapshot/_compare_context.py b/src/inline_snapshot/_compare_context.py new file mode 100644 index 00000000..104a235c --- /dev/null +++ b/src/inline_snapshot/_compare_context.py @@ -0,0 +1,17 @@ +from contextlib import contextmanager + + +def compare_only(): + return _eq_check_only + + +_eq_check_only = False + + +@contextmanager +def compare_context(): + global _eq_check_only + old_eq_only = _eq_check_only + _eq_check_only = True + yield + _eq_check_only = old_eq_only diff --git a/src/inline_snapshot/_inline_snapshot.py b/src/inline_snapshot/_inline_snapshot.py index 26c98478..62dcfcf0 100644 --- a/src/inline_snapshot/_inline_snapshot.py +++ b/src/inline_snapshot/_inline_snapshot.py @@ -1,21 +1,20 @@ import ast import copy import inspect -import tokenize -import warnings -from collections import defaultdict -from pathlib import Path from typing import Any from typing import Dict # noqa from typing import Iterator +from typing import List from typing import Set from typing import Tuple # noqa from typing import TypeVar from executing import Source +from inline_snapshot._adapter.adapter import Adapter +from inline_snapshot._adapter.adapter import adapter_map +from inline_snapshot._source_file import SourceFile -from ._align import add_x -from ._align import align +from ._adapter import get_adapter_type from ._change import CallArg from ._change import Change from ._change import Delete @@ -23,13 +22,14 @@ from ._change import ListInsert from ._change import Replace from ._code_repr import code_repr +from ._compare_context import compare_only from ._exceptions import UsageError -from ._format import format_code from ._sentinels import undefined from ._types import Category -from ._utils import ignore_tokens -from ._utils import normalize -from ._utils import simple_token +from ._types import Snapshot +from ._unmanaged import map_unmanaged +from ._unmanaged import Unmanaged +from ._unmanaged import update_allowed from ._utils import value_to_token @@ -37,7 +37,7 @@ class NotImplementedYet(Exception): pass -snapshots = {} # type: Dict[Tuple[int, int], Snapshot] +snapshots = {} # type: Dict[Tuple[int, int], SnapshotReference] _active = False @@ -86,36 +86,44 @@ def ignore_old_value(): return _update_flags.fix or _update_flags.update -class GenericValue: +class GenericValue(Snapshot): _new_value: Any _old_value: Any _current_op = "undefined" _ast_node: ast.Expr - _source: Source + _file: SourceFile - def _token_of_node(self, node): + def get_adapter(self, value): + return get_adapter_type(value)(self._file) - return list( - normalize( - [ - simple_token(t.type, t.string) - for t in self._source.asttokens().get_tokens(node) - if t.type not in ignore_tokens - ] - ) - ) + def _re_eval(self, value): - def _format(self, text): - if self._source is None: - return text - else: - return format_code(text, Path(self._source.filename)) + def re_eval(old_value, node, value): + if isinstance(old_value, Unmanaged): + old_value.value = value + return - def _token_to_code(self, tokens): - return self._format(tokenize.untokenize(tokens)).strip() + assert type(old_value) is type(value) + + adapter = self.get_adapter(old_value) + if adapter is not None and hasattr(adapter, "items"): + old_items = adapter.items(old_value, node) + new_items = adapter.items(value, node) + assert len(old_items) == len(new_items) + + for old_item, new_item in zip(old_items, new_items): + re_eval(old_item.value, old_item.node, new_item.value) + + else: + if update_allowed(old_value): + if not old_value == value: + raise UsageError( + "snapshot value should not change. Use Is(...) for dynamic snapshot parts." + ) + else: + assert False, "old_value should be converted to Unmanaged" - def _value_to_code(self, value): - return self._token_to_code(value_to_token(value)) + re_eval(self._old_value, self._ast_node, value) def _ignore_old(self): return ( @@ -169,10 +177,12 @@ def __getitem__(self, _item): class UndecidedValue(GenericValue): def __init__(self, old_value, ast_node, source): + + old_value = adapter_map(old_value, map_unmanaged) self._old_value = old_value self._new_value = undefined self._ast_node = ast_node - self._source = source + self._file = SourceFile(source) def _change(self, cls): self.__class__ = cls @@ -183,48 +193,28 @@ def _new_code(self): def _get_changes(self) -> Iterator[Change]: def handle(node, obj): - if isinstance(obj, list): - if not isinstance(node, ast.List): - return - for node_value, value in zip(node.elts, obj): - yield from handle(node_value, value) - elif isinstance(obj, tuple): - if not isinstance(node, ast.Tuple): - return - for node_value, value in zip(node.elts, obj): - yield from handle(node_value, value) - - elif isinstance(obj, dict): - if not isinstance(node, ast.Dict): - return - for value_key, node_key, node_value in zip( - obj.keys(), node.keys, node.values - ): - try: - # this is just a sanity check, dicts should be ordered - node_key = ast.literal_eval(node_key) - except Exception: - pass - else: - assert node_key == value_key - - yield from handle(node_value, obj[value_key]) - else: - if update_allowed(obj): - new_token = value_to_token(obj) - if self._token_of_node(node) != new_token: - new_code = self._token_to_code(new_token) - - yield Replace( - node=self._ast_node, - source=self._source, - new_code=new_code, - flag="update", - old_value=self._old_value, - new_value=self._old_value, - ) - if self._source is not None: + adapter = self.get_adapter(obj) + if adapter is not None and hasattr(adapter, "items"): + for item in adapter.items(obj, node): + yield from handle(item.node, item.value) + return + + if not isinstance(obj, Unmanaged): + new_token = value_to_token(obj) + if self._file._token_of_node(node) != new_token: + new_code = self._file._token_to_code(new_token) + + yield Replace( + node=self._ast_node, + file=self._file, + new_code=new_code, + flag="update", + old_value=self._old_value, + new_value=self._old_value, + ) + + if self._file._source is not None: yield from handle(self._ast_node, self._old_value) # functions which determine the type @@ -250,19 +240,6 @@ def __getitem__(self, item): return self[item] -try: - import dirty_equals # type: ignore -except ImportError: # pragma: no cover - - def update_allowed(value): - return True - -else: - - def update_allowed(value): - return not isinstance(value, dirty_equals.DirtyEquals) - - def clone(obj): new = copy.deepcopy(obj) if not obj == new: @@ -282,233 +259,39 @@ def clone(obj): class EqValue(GenericValue): _current_op = "x == snapshot" + _changes: List[Change] def __eq__(self, other): global _missing_values if self._old_value is undefined: _missing_values += 1 - def use_valid_old_values(old_value, new_value): - - if ( - isinstance(new_value, list) - and isinstance(old_value, list) - or isinstance(new_value, tuple) - and isinstance(old_value, tuple) - ): - diff = add_x(align(old_value, new_value)) - old = iter(old_value) - new = iter(new_value) - result = [] - for c in diff: - if c in "mx": - old_value_element = next(old) - new_value_element = next(new) - result.append( - use_valid_old_values(old_value_element, new_value_element) - ) - elif c == "i": - result.append(next(new)) - elif c == "d": - pass - else: - assert False - - return type(new_value)(result) - - elif isinstance(new_value, dict) and isinstance(old_value, dict): - result = {} - - for key, new_value_element in new_value.items(): - if key in old_value: - result[key] = use_valid_old_values( - old_value[key], new_value_element - ) - else: - result[key] = new_value_element - - return result - - if new_value == old_value: - return old_value - else: - return new_value - - if self._new_value is undefined: - self._new_value = use_valid_old_values(self._old_value, clone(other)) - if self._old_value is undefined or ignore_old_value(): - return True - return _return(self._old_value == other) - else: - return _return(self._new_value == other) + if not compare_only() and self._new_value is undefined: + adapter = Adapter(self._file).get_adapter(self._old_value, other) + it = iter(adapter.assign(self._old_value, self._ast_node, clone(other))) + self._changes = [] + while True: + try: + self._changes.append(next(it)) + except StopIteration as ex: + self._new_value = ex.value + break + + return _return(self._visible_value() == other) + + # if self._new_value is undefined: + # self._new_value = use_valid_old_values(self._old_value, clone(other)) + # if self._old_value is undefined or ignore_old_value(): + # return True + # return _return(self._old_value == other) + # else: + # return _return(self._new_value == other) def _new_code(self): - return self._value_to_code(self._new_value) + return self._file._value_to_code(self._new_value) def _get_changes(self) -> Iterator[Change]: - - assert self._old_value is not undefined - - def check(old_value, old_node, new_value): - - if ( - isinstance(old_node, ast.List) - and isinstance(new_value, list) - and isinstance(old_value, list) - or isinstance(old_node, ast.Tuple) - and isinstance(new_value, tuple) - and isinstance(old_value, tuple) - ): - for e in old_node.elts: - if isinstance(e, ast.Starred): - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self._source.filename, - lineno=e.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return - diff = add_x(align(old_value, new_value)) - old = zip(old_value, old_node.elts) - new = iter(new_value) - old_position = 0 - to_insert = defaultdict(list) - for c in diff: - if c in "mx": - old_value_element, old_node_element = next(old) - new_value_element = next(new) - yield from check( - old_value_element, old_node_element, new_value_element - ) - old_position += 1 - elif c == "i": - new_value_element = next(new) - new_code = self._value_to_code(new_value_element) - to_insert[old_position].append((new_code, new_value_element)) - elif c == "d": - old_value_element, old_node_element = next(old) - yield Delete( - "fix", self._source, old_node_element, old_value_element - ) - old_position += 1 - else: - assert False - - for position, code_values in to_insert.items(): - yield ListInsert( - "fix", self._source, old_node, position, *zip(*code_values) - ) - - return - - elif ( - isinstance(old_node, ast.Dict) - and isinstance(new_value, dict) - and isinstance(old_value, dict) - and len(old_value) == len(old_node.keys) - ): - - for key, value in zip(old_node.keys, old_node.values): - if key is None: - warnings.warn_explicit( - "star-expressions are not supported inside snapshots", - filename=self._source.filename, - lineno=value.lineno, - category=InlineSnapshotSyntaxWarning, - ) - return - - for value, node in zip(old_value.keys(), old_node.keys): - assert node is not None - - try: - # this is just a sanity check, dicts should be ordered - node_value = ast.literal_eval(node) - except: - continue - assert node_value == value - - for key, node in zip(old_value.keys(), old_node.values): - if key in new_value: - # check values with same keys - yield from check(old_value[key], node, new_value[key]) - else: - # delete entries - yield Delete("fix", self._source, node, old_value[key]) - - to_insert = [] - insert_pos = 0 - for key, new_value_element in new_value.items(): - if key not in old_value: - # add new values - to_insert.append((key, new_value_element)) - else: - if to_insert: - new_code = [ - (self._value_to_code(k), self._value_to_code(v)) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self._source, - old_node, - insert_pos, - new_code, - to_insert, - ) - to_insert = [] - insert_pos += 1 - - if to_insert: - new_code = [ - (self._value_to_code(k), self._value_to_code(v)) - for k, v in to_insert - ] - yield DictInsert( - "fix", - self._source, - old_node, - len(old_node.values), - new_code, - to_insert, - ) - - return - - # generic fallback - - # because IsStr() != IsStr() - if type(old_value) is type(new_value) and not update_allowed(new_value): - return - - if old_node is None: - new_token = [] - else: - new_token = value_to_token(new_value) - - if not old_value == new_value: - flag = "fix" - elif ( - self._ast_node is not None - and update_allowed(old_value) - and self._token_of_node(old_node) != new_token - ): - flag = "update" - else: - return - - new_code = self._token_to_code(new_token) - - yield Replace( - node=old_node, - source=self._source, - new_code=new_code, - flag=flag, - old_value=old_value, - new_value=new_value, - ) - - yield from check(self._old_value, self._ast_node, self._new_value) + return iter(self._changes) class MinMaxValue(GenericValue): @@ -535,7 +318,7 @@ def _generic_cmp(self, other): return _return(self.cmp(self._visible_value(), other)) def _new_code(self): - return self._value_to_code(self._new_value) + return self._file._value_to_code(self._new_value) def _get_changes(self) -> Iterator[Change]: new_token = value_to_token(self._new_value) @@ -545,17 +328,17 @@ def _get_changes(self) -> Iterator[Change]: flag = "trim" elif ( self._ast_node is not None - and self._token_of_node(self._ast_node) != new_token + and self._file._token_of_node(self._ast_node) != new_token ): flag = "update" else: return - new_code = self._token_to_code(new_token) + new_code = self._file._token_to_code(new_token) yield Replace( node=self._ast_node, - source=self._source, + file=self._file, new_code=new_code, flag=flag, old_value=self._old_value, @@ -625,7 +408,7 @@ def __contains__(self, item): return _return(item in self._old_value) def _new_code(self): - return self._value_to_code(self._new_value) + return self._file._value_to_code(self._new_value) def _get_changes(self) -> Iterator[Change]: @@ -638,19 +421,25 @@ def _get_changes(self) -> Iterator[Change]: for old_value, old_node in zip(self._old_value, elements): if old_value not in self._new_value: yield Delete( - flag="trim", source=self._source, node=old_node, old_value=old_value + flag="trim", + file=self._file, + node=old_node, + old_value=old_value, ) continue # check for update new_token = value_to_token(old_value) - if old_node is not None and self._token_of_node(old_node) != new_token: - new_code = self._token_to_code(new_token) + if ( + old_node is not None + and self._file._token_of_node(old_node) != new_token + ): + new_code = self._file._token_to_code(new_token) yield Replace( node=old_node, - source=self._source, + file=self._file, new_code=new_code, flag="update", old_value=old_value, @@ -661,10 +450,10 @@ def _get_changes(self) -> Iterator[Change]: if new_values: yield ListInsert( flag="fix", - source=self._source, + file=self._file, node=self._ast_node, position=len(self._old_value), - new_code=[self._value_to_code(v) for v in new_values], + new_code=[self._file._value_to_code(v) for v in new_values], new_values=new_values, ) @@ -678,31 +467,39 @@ def __getitem__(self, index): if self._new_value is undefined: self._new_value = {} - old_value = self._old_value - if old_value is undefined: - _missing_values += 1 - old_value = {} - - child_node = None - if self._ast_node is not None: - assert isinstance(self._ast_node, ast.Dict) - if index in old_value: - pos = list(old_value.keys()).index(index) - child_node = self._ast_node.values[pos] - if index not in self._new_value: + old_value = self._old_value + if old_value is undefined: + _missing_values += 1 + old_value = {} + + child_node = None + if self._ast_node is not None: + assert isinstance(self._ast_node, ast.Dict) + if index in old_value: + pos = list(old_value.keys()).index(index) + child_node = self._ast_node.values[pos] + self._new_value[index] = UndecidedValue( - old_value.get(index, undefined), child_node, self._source + old_value.get(index, undefined), child_node, self._file ) return self._new_value[index] + def _re_eval(self, value): + super()._re_eval(value) + + if self._new_value is not undefined and self._old_value is not undefined: + for key, s in self._new_value.items(): + if key in self._old_value: + s._re_eval(self._old_value[key]) + def _new_code(self): return ( "{" + ", ".join( [ - f"{self._value_to_code(k)}: {v._new_code()}" + f"{self._file._value_to_code(k)}: {v._new_code()}" for k, v in self._new_value.items() if not isinstance(v, UndecidedValue) ] @@ -726,7 +523,7 @@ def _get_changes(self) -> Iterator[Change]: yield from self._new_value[key]._get_changes() else: # delete entries - yield Delete("trim", self._source, node, self._old_value[key]) + yield Delete("trim", self._file, node, self._old_value[key]) to_insert = [] for key, new_value_element in self._new_value.items(): @@ -737,10 +534,10 @@ def _get_changes(self) -> Iterator[Change]: to_insert.append((key, new_value_element._new_code())) if to_insert: - new_code = [(self._value_to_code(k), v) for k, v in to_insert] + new_code = [(self._file._value_to_code(k), v) for k, v in to_insert] yield DictInsert( "create", - self._source, + self._file, self._ast_node, len(self._old_value), new_code, @@ -815,10 +612,12 @@ def snapshot(obj: Any = undefined) -> Any: node = expr.node if node is None: # we can run without knowing of the calling expression but we will not be able to fix code - snapshots[key] = Snapshot(obj, None) + snapshots[key] = SnapshotReference(obj, None) else: assert isinstance(node, ast.Call) - snapshots[key] = Snapshot(obj, expr) + snapshots[key] = SnapshotReference(obj, expr) + else: + snapshots[key]._re_eval(obj) return snapshots[key]._value @@ -835,7 +634,7 @@ def used_externals(tree): ] -class Snapshot: +class SnapshotReference: def __init__(self, value, expr): self._expr = expr node = expr.node.args[0] if expr is not None and expr.node.args else None @@ -853,16 +652,18 @@ def _changes(self): new_code = self._value._new_code() yield CallArg( - "create", - self._value._source, - self._expr.node if self._expr is not None else None, - 0, - None, - new_code, - self._value._old_value, - self._value._new_value, + flag="create", + file=self._value._file, + node=self._expr.node if self._expr is not None else None, + arg_pos=0, + arg_name=None, + new_code=new_code, + new_value=self._value._new_value, ) else: yield from self._value._get_changes() + + def _re_eval(self, obj): + self._value._re_eval(obj) diff --git a/src/inline_snapshot/_is.py b/src/inline_snapshot/_is.py new file mode 100644 index 00000000..1f695397 --- /dev/null +++ b/src/inline_snapshot/_is.py @@ -0,0 +1,6 @@ +class Is: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return self.value == other diff --git a/src/inline_snapshot/_rewrite_code.py b/src/inline_snapshot/_rewrite_code.py index 70eb9b6e..0cab4c56 100644 --- a/src/inline_snapshot/_rewrite_code.py +++ b/src/inline_snapshot/_rewrite_code.py @@ -98,12 +98,8 @@ def __init__(self, change_recorder): self.change_recorder._changes.append(self) self.change_id = self._next_change_id - self._tags = [] type(self)._next_change_id += 1 - def set_tags(self, *tags): - self._tags = tags - def replace(self, node, new_contend, *, filename): assert isinstance(new_contend, str) @@ -128,7 +124,7 @@ def _replace(self, filename, range, new_contend): class SourceFile: - def __init__(self, filename): + def __init__(self, filename: pathlib.Path): self.replacements: list[Replacement] = [] self.filename = filename self.source = self.filename.read_text("utf-8") diff --git a/src/inline_snapshot/_source_file.py b/src/inline_snapshot/_source_file.py new file mode 100644 index 00000000..ba8a94bc --- /dev/null +++ b/src/inline_snapshot/_source_file.py @@ -0,0 +1,51 @@ +import tokenize +from pathlib import Path + +from executing import Source +from inline_snapshot._format import format_code +from inline_snapshot._utils import normalize +from inline_snapshot._utils import simple_token +from inline_snapshot._utils import value_to_token + +from ._utils import ignore_tokens + + +class SourceFile: + _source = Source + + def __init__(self, source): + if isinstance(source, SourceFile): + self._source = source._source + else: + self._source = source + + @property + def filename(self): + return self._source.filename + + def _format(self, text): + if self._source is None: + return text + else: + return format_code(text, Path(self._source.filename)) + + def asttokens(self): + return self._source.asttokens() + + def _token_to_code(self, tokens): + return self._format(tokenize.untokenize(tokens)).strip() + + def _value_to_code(self, value): + return self._token_to_code(value_to_token(value)) + + def _token_of_node(self, node): + + return list( + normalize( + [ + simple_token(t.type, t.string) + for t in self._source.asttokens().get_tokens(node) + if t.type not in ignore_tokens + ] + ) + ) diff --git a/src/inline_snapshot/_unmanaged.py b/src/inline_snapshot/_unmanaged.py new file mode 100644 index 00000000..5e46b9b5 --- /dev/null +++ b/src/inline_snapshot/_unmanaged.py @@ -0,0 +1,41 @@ +from ._is import Is +from ._types import Snapshot + +try: + import dirty_equals # type: ignore +except ImportError: # pragma: no cover + + def is_dirty_equal(value): + return False + +else: + + def is_dirty_equal(value): + return isinstance(value, dirty_equals.DirtyEquals) or ( + isinstance(value, type) and issubclass(value, dirty_equals.DirtyEquals) + ) + + +def update_allowed(value): + return not (is_dirty_equal(value) or isinstance(value, (Is, Snapshot))) # type: ignore + + +def is_unmanaged(value): + return not update_allowed(value) + + +class Unmanaged: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + assert not isinstance(other, Unmanaged) + + return self.value == other + + +def map_unmanaged(value): + if is_unmanaged(value): + return Unmanaged(value) + else: + return value diff --git a/src/inline_snapshot/syntax_warnings.py b/src/inline_snapshot/syntax_warnings.py new file mode 100644 index 00000000..35dc21a0 --- /dev/null +++ b/src/inline_snapshot/syntax_warnings.py @@ -0,0 +1,2 @@ +class InlineSnapshotSyntaxWarning(Warning): + pass diff --git a/src/inline_snapshot/testing/_example.py b/src/inline_snapshot/testing/_example.py index b31627bd..0e377f91 100644 --- a/src/inline_snapshot/testing/_example.py +++ b/src/inline_snapshot/testing/_example.py @@ -5,6 +5,7 @@ import platform import re import subprocess as sp +import traceback from argparse import ArgumentParser from pathlib import Path from tempfile import TemporaryDirectory @@ -85,6 +86,14 @@ def __init__(self, files: str | dict[str, str]): self.files = files + self.dump_files() + + def dump_files(self): + for name, content in self.files.items(): + print(f"file: {name}") + print(content) + print() + def _write_files(self, dir: Path): for name, content in self.files.items(): (dir / name).write_text(content) @@ -151,6 +160,7 @@ def run_inline( try: for filename in tmp_path.glob("*.py"): globals: dict[str, Any] = {} + print("run> pytest", filename) exec( compile(filename.read_text("utf-8"), filename, "exec"), globals, @@ -161,6 +171,7 @@ def run_inline( if k.startswith("test_") and callable(v): v() except Exception as e: + traceback.print_exc() raised_exception = e finally: diff --git a/tests/adapter/test_dataclass.py b/tests/adapter/test_dataclass.py new file mode 100644 index 00000000..a0b1176a --- /dev/null +++ b/tests/adapter/test_dataclass.py @@ -0,0 +1,391 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + +from tests.warns import warns + + +def test_unmanaged(): + + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + assert A(a=2,b=4) == snapshot(A(a=1,b=Is(1))), "not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + assert A(a=2,b=4) == snapshot(A(a=2,b=Is(1))), "not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_reeval(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + for c in "ab": + assert A(a=1,b=c) == snapshot(A(a=2,b=Is(c))) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass + +@dataclass +class A: + a:int + b:int + +def test_something(): + for c in "ab": + assert A(a=1,b=c) == snapshot(A(a=1,b=Is(c))) +""" + } + ), + ) + + +def test_default_value(): + Example( + """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:int=field(default_factory=list) + +def test_something(): + for c in "ab": + assert A(a=c) == snapshot(A(a=Is(c),b=2,c=[])) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot,Is +from dataclasses import dataclass,field + +@dataclass +class A: + a:int + b:int=2 + c:int=field(default_factory=list) + +def test_something(): + for c in "ab": + assert A(a=c) == snapshot(A(a=Is(c))) +""" + } + ), + ) + + +def test_disabled(executing_used): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +def test_something(): + assert A(a=3) == snapshot(A(a=5)),"not equal" +""" + ).run_inline( + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_starred_warns(): + with warns( + snapshot( + [ + ( + 10, + "InlineSnapshotSyntaxWarning: star-expressions are not supported inside snapshots", + ) + ] + ), + include_line=True, + ): + Example( + """ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +def test_something(): + assert A(a=3) == snapshot(A(**{"a":5})),"not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_add_argument(): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=0 + b:int=0 + c:int=0 + +def test_something(): + assert A(a=3,b=3,c=3) == snapshot(A(b=3)),"not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int=0 + b:int=0 + c:int=0 + +def test_something(): + assert A(a=3,b=3,c=3) == snapshot(A(a = 3, b=3, c = 3)),"not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_positional_star_args(): + + with warns( + snapshot( + [ + "InlineSnapshotSyntaxWarning: star-expressions are not supported inside snapshots" + ] + ) + ): + Example( + """\ +from inline_snapshot import snapshot +from dataclasses import dataclass + +@dataclass +class A: + a:int + +def test_something(): + assert A(a=3) == snapshot(A(*[],a=3)),"not equal" +""" + ).run_inline( + ["--inline-snapshot=report"], + ) + + +def test_remove_positional_argument(): + Example( + """\ +from inline_snapshot import snapshot + +from inline_snapshot._adapter.dataclass_adapter import DataclassAdapter + + +class L: + def __init__(self,*l): + self.l=l + + def __eq__(self,other): + if not isinstance(other,L): + return NotImplemented + return other.l==self.l + +class LAdapter(DataclassAdapter): + @classmethod + def check_type(cls, typ): + return issubclass(typ,L) + + @classmethod + def arguments(cls, value): + return (value.l,{}) + + @classmethod + def argument(cls, value, pos_or_name): + assert isinstance(pos_or_name,int) + return value.l[pos_or_name] + +def test_L1(): + assert L(1,2) == snapshot(L(1)), "not equal" + +def test_L2(): + assert L(1,2) == snapshot(L(1, 2, 3)), "not equal" +""" + ).run_pytest( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + +from inline_snapshot._adapter.dataclass_adapter import DataclassAdapter + + +class L: + def __init__(self,*l): + self.l=l + + def __eq__(self,other): + if not isinstance(other,L): + return NotImplemented + return other.l==self.l + +class LAdapter(DataclassAdapter): + @classmethod + def check_type(cls, typ): + return issubclass(typ,L) + + @classmethod + def arguments(cls, value): + return (value.l,{}) + + @classmethod + def argument(cls, value, pos_or_name): + assert isinstance(pos_or_name,int) + return value.l[pos_or_name] + +def test_L1(): + assert L(1,2) == snapshot(L(1, 2)), "not equal" + +def test_L2(): + assert L(1,2) == snapshot(L(1, 2)), "not equal" +""" + } + ), + ) + + +def test_namedtuple(): + Example( + """\ +from inline_snapshot import snapshot +from collections import namedtuple + +T=namedtuple("T","a,b") + +def test_tuple(): + assert T(a=1,b=2) == snapshot(T(a=1, b=3)), "not equal" +""" + ).run_pytest( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from collections import namedtuple + +T=namedtuple("T","a,b") + +def test_tuple(): + assert T(a=1,b=2) == snapshot(T(a=1, b=2)), "not equal" +""" + } + ), + ) + + +def test_defaultdict(): + Example( + """\ +from inline_snapshot import snapshot +from collections import defaultdict + + +def test_tuple(): + d=defaultdict(list) + d[1].append(2) + assert d == snapshot(defaultdict(list, {1: [3]})), "not equal" +""" + ).run_pytest( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from collections import defaultdict + + +def test_tuple(): + d=defaultdict(list) + d[1].append(2) + assert d == snapshot(defaultdict(list, {1: [2]})), "not equal" +""" + } + ), + ) diff --git a/tests/adapter/test_dict.py b/tests/adapter/test_dict.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/adapter/test_general.py b/tests/adapter/test_general.py new file mode 100644 index 00000000..52d1daf9 --- /dev/null +++ b/tests/adapter/test_general.py @@ -0,0 +1,47 @@ +from inline_snapshot import snapshot +from inline_snapshot.testing import Example + + +def test_adapter_mismatch(): + + Example( + """\ +from inline_snapshot import snapshot + + +def test_thing(): + assert [1,2] == snapshot({1:2}) + + """ + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot + + +def test_thing(): + assert [1,2] == snapshot([1, 2]) + + \ +""" + } + ), + ) + + +def test_reeval(): + + Example( + """\ +from inline_snapshot import snapshot,Is + + +def test_thing(): + for i in (1,2): + assert {1:i} == snapshot({1:Is(i)}) + assert [i] == [Is(i)] + assert (i,) == (Is(i),) +""" + ).run_pytest(["--inline-snapshot=short-report"], report=snapshot("")) diff --git a/tests/adapter/test_sequence.py b/tests/adapter/test_sequence.py new file mode 100644 index 00000000..77ea2853 --- /dev/null +++ b/tests/adapter/test_sequence.py @@ -0,0 +1,94 @@ +import pytest +from inline_snapshot._inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_list_adapter_create_inner_snapshot(): + + Example( + """\ +from inline_snapshot import snapshot +from dirty_equals import IsInt + +def test_list(): + + assert [1,2,3,4] == snapshot([1,IsInt(),snapshot(),4]),"not equal" +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from dirty_equals import IsInt + +def test_list(): + + assert [1,2,3,4] == snapshot([1,IsInt(),snapshot(3),4]),"not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +def test_list_adapter_fix_inner_snapshot(): + + Example( + """\ +from inline_snapshot import snapshot +from dirty_equals import IsInt + +def test_list(): + + assert [1,2,3,4] == snapshot([1,IsInt(),snapshot(8),4]),"not equal" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from inline_snapshot import snapshot +from dirty_equals import IsInt + +def test_list(): + + assert [1,2,3,4] == snapshot([1,IsInt(),snapshot(3),4]),"not equal" +""" + } + ), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) + + +@pytest.mark.no_rewriting +def test_list_adapter_reeval(executing_used): + + Example( + """\ +from inline_snapshot import snapshot,Is + +def test_list(): + + for i in (1,2,3): + assert [1,i] == snapshot([1,Is(i)]),"not equal" +""" + ).run_inline( + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +not equal\ +""" + ), + ) diff --git a/tests/conftest.py b/tests/conftest.py index 4aeb5730..dbef0a28 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,6 +21,7 @@ from inline_snapshot._format import format_code from inline_snapshot._inline_snapshot import Flags from inline_snapshot._rewrite_code import ChangeRecorder +from inline_snapshot._types import Category from inline_snapshot.testing._example import snapshot_env pytest_plugins = "pytester" @@ -65,7 +66,7 @@ def w(source_code, *, flags="", reported_flags=None, number=1): @pytest.fixture() -def source(tmp_path): +def source(tmp_path: Path): filecount = 1 @dataclass @@ -76,8 +77,8 @@ class Source: number_snapshots: int = 0 number_changes: int = 0 - def run(self, *flags): - flags = Flags({*flags}) + def run(self, *flags_arg: Category): + flags = Flags({*flags_arg}) nonlocal filecount filename: Path = tmp_path / f"test_{filecount}.py" @@ -311,7 +312,10 @@ def format(self): ) def pyproject(self, source): - (pytester.path / "pyproject.toml").write_text(source, "utf-8") + self.write_file("pyproject.toml", source) + + def write_file(self, filename, content): + (pytester.path / filename).write_text(content, "utf-8") def storage(self): dir = pytester.path / ".inline-snapshot" / "external" diff --git a/tests/test_change.py b/tests/test_change.py new file mode 100644 index 00000000..cbe82589 --- /dev/null +++ b/tests/test_change.py @@ -0,0 +1,90 @@ +import ast + +import pytest +from executing import Source +from inline_snapshot._change import apply_all +from inline_snapshot._change import CallArg +from inline_snapshot._change import Delete +from inline_snapshot._change import Replace +from inline_snapshot._inline_snapshot import snapshot +from inline_snapshot._rewrite_code import ChangeRecorder +from inline_snapshot._source_file import SourceFile + + +@pytest.fixture +def check_change(tmp_path): + i = 0 + + def w(source, changes, new_code): + nonlocal i + + filename = tmp_path / f"test_{i}.py" + i += 1 + + filename.write_text(source) + print(f"\ntest: {source}") + + source = Source.for_filename(filename) + module = source.tree + context = SourceFile(source) + + call = module.body[0].value + assert isinstance(call, ast.Call) + + with ChangeRecorder().activate() as cr: + apply_all(changes(context, call)) + + cr.virtual_write() + + cr.dump() + + assert list(cr.files())[0].source == new_code + + return w + + +def test_change_function_args(check_change): + + check_change( + "f(a,b=2)", + lambda source, call: [ + Replace( + flag="fix", + file=source, + node=call.args[0], + new_code="22", + old_value=0, + new_value=0, + ) + ], + snapshot("f(22,b=2)"), + ) + + check_change( + "f(a,b=2)", + lambda source, call: [ + Delete( + flag="fix", + file=source, + node=call.args[0], + old_value=0, + ) + ], + snapshot("f(b=2)"), + ) + + check_change( + "f(a,b=2)", + lambda source, call: [ + CallArg( + flag="fix", + file=source, + node=call, + arg_pos=0, + arg_name=None, + new_code="22", + new_value=22, + ) + ], + snapshot("f(22, a,b=2)"), + ) diff --git a/tests/test_dirty_equals.py b/tests/test_dirty_equals.py new file mode 100644 index 00000000..1bc897a3 --- /dev/null +++ b/tests/test_dirty_equals.py @@ -0,0 +1,152 @@ +from inline_snapshot._inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_dirty_equals_repr(): + Example( + """\ +from inline_snapshot import snapshot +from dirty_equals import IsStr + +def test_something(): + assert [IsStr()] == snapshot() + """ + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot({}), + raises=snapshot( + """\ +UsageError: +inline-snapshot uses `copy.deepcopy` to copy objects, +but the copied object is not equal to the original one: + +original: [HasRepr(IsStr, '< type(obj) can not be compared with == >')] +copied: [HasRepr(IsStr, '< type(obj) can not be compared with == >')] + +Please fix the way your object is copied or your __eq__ implementation. +""" + ), + ) + + +def test_compare_dirty_equals_twice() -> None: + + Example( + """ +from dirty_equals import IsStr +from inline_snapshot import snapshot + +for x in 'ab': + assert x == snapshot(IsStr()) + assert [x,5] == snapshot([IsStr(),3]) + assert {'a':x,'b':5} == snapshot({'a':IsStr(),'b':3}) + +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ + +from dirty_equals import IsStr +from inline_snapshot import snapshot + +for x in 'ab': + assert x == snapshot(IsStr()) + assert [x,5] == snapshot([IsStr(),5]) + assert {'a':x,'b':5} == snapshot({'a':IsStr(),'b':5}) + +""" + } + ), + ) + + +def test_dirty_equals_in_unused_snapshot() -> None: + + Example( + """ +from dirty_equals import IsStr +from inline_snapshot import snapshot,Is + +snapshot([IsStr(),3]) +snapshot((IsStr(),3)) +snapshot({1:IsStr(),2:3}) +snapshot({1+1:2}) + +t=(1,2) +d={1:2} +l=[1,2] +snapshot([Is(t),Is(d),Is(l)]) +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot({}), + ) + + +def test_now_like_dirty_equals(): + # test for cases like https://github.com/15r10nk/inline-snapshot/issues/116 + + Example( + """ +from dirty_equals import DirtyEquals +from inline_snapshot import snapshot + + +def test_time(): + + now = 5 + + class Now(DirtyEquals): + def equals(self, other): + return other == now + + assert 5 == snapshot(Now()) + + now = 6 + + assert 5 == snapshot(Now()), "different time" +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot({}), + raises=snapshot( + """\ +AssertionError: +different time\ +""" + ), + ) + + +def test_dirty_equals_with_changing_args() -> None: + + Example( + """\ +from dirty_equals import IsInt +from inline_snapshot import snapshot + +def test_number(): + + for i in range(5): + assert ["a",i] == snapshot(["e",IsInt(gt=i-1,lt=i+1)]) + +""" + ).run_inline( + ["--inline-snapshot=fix"], + changed_files=snapshot( + { + "test_something.py": """\ +from dirty_equals import IsInt +from inline_snapshot import snapshot + +def test_number(): + + for i in range(5): + assert ["a",i] == snapshot(["a",IsInt(gt=i-1,lt=i+1)]) + +""" + } + ), + ) diff --git a/tests/test_docs.py b/tests/test_docs.py index 14a4046f..d1543afd 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -1,11 +1,233 @@ +import itertools import platform import re import sys import textwrap +from collections import defaultdict +from dataclasses import dataclass from pathlib import Path +from typing import Optional import inline_snapshot._inline_snapshot import pytest +from inline_snapshot import snapshot +from inline_snapshot.extra import raises + + +@dataclass +class Block: + code: str + code_header: Optional[str] + block_options: str + line: int + + +def map_code_blocks(file, func, fix=False): + + block_start = re.compile("( *)``` *python(.*)") + block_end = re.compile("```.*") + + header = re.compile("") + + current_code = file.read_text("utf-8") + new_lines = [] + block_lines = [] + options = set() + is_block = False + code = None + indent = "" + block_start_linenum = None + block_options = None + code_header = None + header_line = "" + + for linenumber, line in enumerate(current_code.splitlines(), start=1): + m = block_start.fullmatch(line) + if m and not is_block: + # ``` python + block_start_linenum = linenumber + indent = m[1] + block_options = m[2] + block_lines = [] + is_block = True + continue + + if block_end.fullmatch(line.strip()) and is_block: + # ``` + is_block = False + + code = "\n".join(block_lines) + "\n" + code = textwrap.dedent(code) + if file.suffix == ".py": + code = code.replace("\\\\", "\\") + + try: + new_block = func( + Block( + code=code, + code_header=code_header, + block_options=block_options, + line=block_start_linenum, + ) + ) + except Exception: + print(f"error at block at line {block_start_linenum}") + print(f"{code_header=}") + print(f"{block_options=}") + print(code) + raise + + if new_block.code_header is not None: + new_lines.append(f"{indent}") + + new_lines.append( + f"{indent}``` {('python '+new_block.block_options.strip()).strip()}" + ) + + new_code = new_block.code.rstrip() + if file.suffix == ".py": + new_code = new_code.replace("\\", "\\\\") + new_code = textwrap.indent(new_code, indent) + + new_lines.append(new_code) + + new_lines.append(f"{indent}```") + + header_line = "" + code_header = None + + continue + + if is_block: + block_lines.append(line) + continue + + m = header.fullmatch(line.strip()) + if m: + # comment + header_line = line + code_header = m[1].strip() + continue + else: + if header_line: + new_lines.append(header_line) + code_header = None + header_line = "" + + new_lines.append(line) + + new_code = "\n".join(new_lines) + "\n" + + if fix: + file.write_text(new_code) + else: + assert current_code.splitlines() == new_code.splitlines() + assert current_code == new_code + + +def test_map_code_blocks(tmp_path): + + file = tmp_path / "example.md" + + def test_doc( + markdown_code, + handle_block=lambda block: exec(block.code), + blocks=[], + exception="", + new_markdown_code=None, + ): + + file.write_text(markdown_code) + + recorded_blocks = [] + + with raises(exception): + + def test_block(block): + handle_block(block) + recorded_blocks.append(block) + return block + + map_code_blocks(file, test_block, True) + assert recorded_blocks == blocks + map_code_blocks(file, test_block, False) + + recorded_markdown_code = file.read_text() + if recorded_markdown_code != markdown_code: + assert new_markdown_code == recorded_markdown_code + else: + assert new_markdown_code == None + + test_doc( + """ +``` python +1 / 0 +``` +""", + exception=snapshot("ZeroDivisionError: division by zero"), + ) + + test_doc( + """\ +text +``` python +print(1 + 1) +``` +text + +``` python hl_lines="1 2 3" +print(1 - 1) +``` +text +""", + blocks=snapshot( + [ + Block( + code="print(1 + 1)\n", code_header=None, block_options="", line=2 + ), + Block( + code="print(1 - 1)\n", + code_header="inline-snapshot: create test", + block_options=' hl_lines="1 2 3"', + line=7, + ), + ] + ), + ) + + def change_block(block): + block.code = "# removed" + block.code_header = "header" + block.block_options = "option a b c" + + test_doc( + """\ +text +``` python +print(1 + 1) +``` +""", + handle_block=change_block, + blocks=snapshot( + [ + Block( + code="# removed", + code_header="header", + block_options="option a b c", + line=2, + ) + ] + ), + new_markdown_code=snapshot( + """\ +text + +``` python option a b c +# removed +``` +""" + ), + ) @pytest.mark.skipif( @@ -14,7 +236,7 @@ ) @pytest.mark.skipif( sys.version_info[:2] != (3, 12), - reason="\\r in stdout can cause problems in snapshot strings", + reason="there is no reason to test the doc with different python versions", ) @pytest.mark.parametrize( "file", @@ -36,19 +258,7 @@ def test_docs(project, file, subtests): * `outcome-passed=2` to check for the pytest test outcome """ - block_start = re.compile("( *)``` *python.*") - block_end = re.compile("```.*") - - header = re.compile("") - - text = file.read_text("utf-8") - new_lines = [] - block_lines = [] - options = set() - is_block = False - code = None - indent = "" - first_block = True + last_code = None project.pyproject( """ @@ -57,132 +267,104 @@ def test_docs(project, file, subtests): """ ) - for linenumber, line in enumerate(text.splitlines(), start=1): - m = block_start.fullmatch(line) - if m and is_block == True: - block_start_line = line - indent = m[1] - block_lines = [] - continue + extra_files = defaultdict(list) - if block_end.fullmatch(line.strip()) and is_block: - with subtests.test(line=linenumber): - is_block = False + def test_block(block: Block): + if block.code_header is None: + return block - last_code = code - code = "\n".join(block_lines) + "\n" - code = textwrap.dedent(code) - if file.suffix == ".py": - code = code.replace("\\\\", "\\") + if block.code_header.startswith("inline-snapshot-lib:"): + extra_files[block.code_header.split()[1]].append(block.code) + return block - flags = options & {"fix", "update", "create", "trim"} + if block.code_header.startswith("todo-inline-snapshot:"): + return block + assert False - args = ["--inline-snapshot", ",".join(flags)] if flags else [] + nonlocal last_code + with subtests.test(line=block.line): - if flags and "first_block" not in options: - project.setup(last_code) - else: - project.setup(code) + code = block.code - result = project.run(*args) + options = set(block.code_header.split()) - print("flags:", flags) + flags = options & {"fix", "update", "create", "trim"} - new_code = code - if flags: - new_code = project.source + args = ["--inline-snapshot", ",".join(flags)] if flags else [] - if "show_error" in options: - new_code = new_code.split("# Error:")[0] - new_code += "# Error:\n" + textwrap.indent( - result.errorLines(), "# " - ) + if flags and "first_block" not in options: + project.setup(last_code) + else: + project.setup(code) - print("new code:") - print(new_code) - print("expected code:") - print(code) + if extra_files: + all_files = [ + [(key, file) for file in files] + for key, files in extra_files.items() + ] + for files in itertools.product(*all_files): + for filename, content in files: + project.write_file(filename, content) + result = project.run(*args) - if ( - inline_snapshot._inline_snapshot._update_flags.fix - ): # pragma: no cover - flags_str = " ".join( - sorted(flags) - + sorted(options & {"first_block", "show_error"}) - + [ - f"outcome-{k}={v}" - for k, v in result.parseoutcomes().items() - if k in ("failed", "errors", "passed") - ] - ) - header_line = f"{indent}" + else: - new_lines.append(header_line) + result = project.run(*args) - from inline_snapshot._align import align - - linenum = 1 - hl_lines = "" - if last_code is not None and "first_block" not in options: - changed_lines = [] - alignment = align(last_code.split("\n"), new_code.split("\n")) - for c in alignment: - if c == "d": - continue - elif c == "m": - linenum += 1 - else: - changed_lines.append(str(linenum)) - linenum += 1 - if changed_lines: - hl_lines = f' hl_lines="{" ".join(changed_lines)}"' + print("flags:", flags, repr(block.block_options)) + + new_code = code + if flags: + new_code = project.source + + if "show_error" in options: + new_code = new_code.split("# Error:")[0] + new_code += "# Error:\n" + textwrap.indent(result.errorLines(), "# ") + + print("new code:") + print(new_code) + print("expected code:") + print(code) + + block.code_header = "inline-snapshot: " + " ".join( + sorted(flags) + + sorted(options & {"first_block", "show_error"}) + + [ + f"outcome-{k}={v}" + for k, v in result.parseoutcomes().items() + if k in ("failed", "errors", "passed") + ] + ) + + from inline_snapshot._align import align + + linenum = 1 + hl_lines = "" + if last_code is not None and "first_block" not in options: + changed_lines = [] + alignment = align(last_code.split("\n"), new_code.split("\n")) + for c in alignment: + if c == "d": + continue + elif c == "m": + linenum += 1 else: - assert False, "no lines changed" - - new_lines.append(f"{indent}``` python{hl_lines}") - - if ( - inline_snapshot._inline_snapshot._update_flags.fix - ): # pragma: no cover - new_code = new_code.rstrip("\n") - if file.suffix == ".py": - new_code = new_code.replace("\\", "\\\\") - new_code = textwrap.indent(new_code, indent) - - new_lines.append(new_code) + changed_lines.append(str(linenum)) + linenum += 1 + if changed_lines: + hl_lines = f'hl_lines="{" ".join(changed_lines)}"' else: - new_lines += block_lines + assert False, "no lines changed" + block.block_options = hl_lines - new_lines.append(line) + block.code = new_code - if not inline_snapshot._inline_snapshot._update_flags.fix: - if flags: - assert result.ret == 0 - else: - assert { - f"outcome-{k}={v}" - for k, v in result.parseoutcomes().items() - if k in ("failed", "errors", "passed") - } == {flag for flag in options if flag.startswith("outcome-")} - assert code == new_code - else: # pragma: no cover - pass - - continue - - m = header.fullmatch(line.strip()) - if m: - options = set(m.group(1).split()) - if first_block: - options.add("first_block") - first_block = False - header_line = line - is_block = True + if flags: + assert result.ret == 0 - if is_block: - block_lines.append(line) - else: - new_lines.append(line) + last_code = code + return block - if inline_snapshot._inline_snapshot._update_flags.fix: # pragma: no cover - file.write_text("\n".join(new_lines) + "\n", "utf-8") + map_code_blocks( + file, test_block, inline_snapshot._inline_snapshot._update_flags.fix + ) diff --git a/tests/test_inline_snapshot.py b/tests/test_inline_snapshot.py index b6cfa2bb..7b5c57ef 100644 --- a/tests/test_inline_snapshot.py +++ b/tests/test_inline_snapshot.py @@ -832,62 +832,6 @@ def test_thing(): assert result.report == snapshot("") -def test_compare_dirty_equals_twice() -> None: - - Example( - """ -from dirty_equals import IsStr -from inline_snapshot import snapshot - -for x in 'ab': - assert x == snapshot(IsStr()) - assert [x,5] == snapshot([IsStr(),3]) - assert {'a':x,'b':5} == snapshot({'a':IsStr(),'b':3}) - -""" - ).run_inline( - ["--inline-snapshot=fix"], - changed_files=snapshot( - { - "test_something.py": """\ - -from dirty_equals import IsStr -from inline_snapshot import snapshot - -for x in 'ab': - assert x == snapshot(IsStr()) - assert [x,5] == snapshot([IsStr(),5]) - assert {'a':x,'b':5} == snapshot({'a':IsStr(),'b':5}) - -""" - } - ), - ) - - -def test_dirty_equals_in_unused_snapshot() -> None: - - Example( - """ -from dirty_equals import IsStr -from inline_snapshot import snapshot - -snapshot([IsStr(),3]) -snapshot((IsStr(),3)) -snapshot({1:IsStr(),2:3}) -snapshot({1+1:2}) - -t=(1,2) -d={1:2} -l=[1,2] -snapshot([t,d,l]) -""" - ).run_inline( - ["--inline-snapshot=fix"], - changed_files=snapshot({}), - ) - - @dataclass class Warning: message: str @@ -928,7 +872,7 @@ def test_starred_warns_list(): """ from inline_snapshot import snapshot -assert [5] == snapshot([*[4]]) +assert [5] == snapshot([*[5]]) """ ).run_inline(["--inline-snapshot=fix"]) @@ -949,57 +893,49 @@ def test_starred_warns_dict(): """ from inline_snapshot import snapshot -assert {1:3} == snapshot({**{1:2}}) +assert {1:3} == snapshot({**{1:3}}) """ ).run_inline(["--inline-snapshot=fix"]) -def test_now_like_dirty_equals(): - # test for cases like https://github.com/15r10nk/inline-snapshot/issues/116 +def test_is(): Example( """ -from dirty_equals import DirtyEquals -from inline_snapshot import snapshot - - -def test_time(): +from inline_snapshot import snapshot,Is - now = 5 - - class Now(DirtyEquals): - def equals(self, other): - return other == now - - assert now == snapshot(Now()) +def test_Is(): + for i in range(3): + assert ["hello",i] == snapshot(["hi",Is(i)]) + assert ["hello",i] == snapshot({1:["hi",Is(i)]})[i] +""" + ).run_inline( + ["--inline-snapshot=create"], + changed_files=snapshot( + { + "test_something.py": """\ - now = 6 +from inline_snapshot import snapshot,Is - assert 5 == snapshot(Now()) +def test_Is(): + for i in range(3): + assert ["hello",i] == snapshot(["hi",Is(i)]) + assert ["hello",i] == snapshot({1:["hi",Is(i)], 0: ["hello", 0], 2: ["hello", 2]})[i] """ + } + ), ).run_inline( ["--inline-snapshot=fix"], changed_files=snapshot( { "test_something.py": """\ -from dirty_equals import DirtyEquals -from inline_snapshot import snapshot - - -def test_time(): - - now = 5 - - class Now(DirtyEquals): - def equals(self, other): - return other == now - - assert now == snapshot(Now()) - - now = 6 +from inline_snapshot import snapshot,Is - assert 5 == snapshot(5) +def test_Is(): + for i in range(3): + assert ["hello",i] == snapshot(["hello",Is(i)]) + assert ["hello",i] == snapshot({1:["hello",Is(i)], 0: ["hello", 0], 2: ["hello", 2]})[i] """ } ), diff --git a/tests/test_is.py b/tests/test_is.py new file mode 100644 index 00000000..cbbfa850 --- /dev/null +++ b/tests/test_is.py @@ -0,0 +1,22 @@ +from inline_snapshot._inline_snapshot import snapshot +from inline_snapshot.testing._example import Example + + +def test_missing_is(): + + Example( + """\ +from inline_snapshot import snapshot + +def test_is(): + for i in (1,2): + assert i == snapshot(i) + """ + ).run_inline( + raises=snapshot( + """\ +UsageError: +snapshot value should not change. Use Is(...) for dynamic snapshot parts.\ +""" + ) + ) diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py index 0165be50..c4a16e63 100644 --- a/tests/test_pydantic.py +++ b/tests/test_pydantic.py @@ -33,7 +33,7 @@ class M(BaseModel): age:int=4 def test_pydantic(): - assert M(size=5,name="Tom")==snapshot(M(name="Tom", size=5)) + assert M(size=5,name="Tom")==snapshot(M(size=5, name="Tom")) \ """ diff --git a/tests/test_warns.py b/tests/test_warns.py new file mode 100644 index 00000000..327971de --- /dev/null +++ b/tests/test_warns.py @@ -0,0 +1,34 @@ +import warnings + +from inline_snapshot import snapshot + +from tests.warns import warns + + +def test_warns(): + + def warning(): + warnings.warn_explicit( + message="bad things happen", + category=SyntaxWarning, + filename="file.py", + lineno=5, + ) + + with warns( + snapshot([("file.py", 5, "SyntaxWarning: bad things happen")]), + include_line=True, + include_file=True, + ): + warning() + + with warns( + snapshot([("file.py", "SyntaxWarning: bad things happen")]), + include_file=True, + ): + warning() + + with warns( + snapshot(["SyntaxWarning: bad things happen"]), + ): + warning() diff --git a/tests/warns.py b/tests/warns.py new file mode 100644 index 00000000..6cf62557 --- /dev/null +++ b/tests/warns.py @@ -0,0 +1,24 @@ +import contextlib +import warnings + + +@contextlib.contextmanager +def warns(expected_warnings=[], include_line=False, include_file=False): + with warnings.catch_warnings(record=True) as result: + warnings.simplefilter("always") + yield + + def make_warning(w): + message = f"{w.category.__name__}: {w.message}" + if not include_line and not include_file: + return message + message = (message,) + + if include_line: + message = (w.lineno, *message) + if include_file: + message = (w.filename, *message) + + return message + + assert [make_warning(w) for w in result] == expected_warnings