|
| 1 | +"""Compare SDMX artefacts.""" |
| 2 | + |
| 3 | +import datetime |
| 4 | +import enum |
| 5 | +import logging |
| 6 | +import textwrap |
| 7 | +from collections import defaultdict |
| 8 | +from collections.abc import Iterable |
| 9 | +from copy import copy |
| 10 | +from dataclasses import dataclass, fields, is_dataclass |
| 11 | +from functools import singledispatch |
| 12 | +from typing import Any, TypeVar, Union |
| 13 | + |
| 14 | +import lxml.etree |
| 15 | + |
| 16 | +from . import urn |
| 17 | +from .model import internationalstring |
| 18 | + |
# Module-level logger; Options.log() emits through it.
log = logging.getLogger(__name__)


# Dotted "<context>.<field>" names whose comparison result is disregarded by
# compare_dataclass(): a field listed here never makes a comparison fail.
IGNORE_CONTEXT = {"Categorisation.artefact"}
# Memo of objects already compared. Keys are Options._memo_key, that is
# (id(base), id(options)); values are the id() of every object seen so far.
VISITED: dict[tuple[int, int], set[int]] = defaultdict(set)
| 24 | + |
| 25 | + |
class Comparable:
    """Mix-in that gives an object a :meth:`.compare` method."""

    def compare(self, other, strict: bool = True, **options) -> bool:
        """Return :any:`True` if `self` is the same as `other`.

        An :class:`Options` instance is built with `self` as its base object, the
        given `strict` value, and any further keyword `options`; the comparison is
        then delegated to the module-level :func:`compare`.
        """
        opts = Options(self, strict=strict, **options)
        return compare(self, other, opts)
| 36 | + |
| 37 | + |
@dataclass
class Options:
    """Settings that control one comparison."""

    #: Root object of a recursive comparison. Used internally for memoization, to
    #: improve performance.
    base: Any

    #: If :any:`True`, objects compare equal even when
    #: :attr:`.IdentifiableArtefact.urn` is :any:`None` on one or both sides,
    #: provided the URNs implied by their other attributes (as returned by
    #: :func:`sdmx.urn.make`) are the same.
    allow_implied_urn: bool = True

    #: Strict comparison. With :any:`True` (the default), attributes and associated
    #: objects must compare exactly equal; with :any:`False`, :any:`None` values on
    #: either side are permitted.
    strict: bool = True

    #: Minimum level for log messages.
    log_level: int = logging.NOTSET

    #: Verbose comparison: keep comparing even after a definitive :any:`False`
    #: result. When :attr:`log_level` is not set, :py:`verbose = True` implies
    #: :py:`log_level = logging.DEBUG`.
    verbose: bool = False

    _memo_key: tuple[int, int] = (0, 0)

    def __post_init__(self) -> None:
        # Build the memoization key and start from an empty set of visited objects
        key = (id(self.base), id(self))
        self._memo_key = key
        VISITED[key].clear()

        # Without an explicit log level, derive one from `verbose`
        if self.log_level == logging.NOTSET:
            fallback = {False: logging.INFO, True: logging.DEBUG}
            self.log_level = fallback[self.verbose]

    def log(self, message: str, level: int = logging.INFO) -> None:
        """Emit `message` at `level` through the module logger.

        Nothing is emitted unless `level` is at least :attr:`log_level`.
        """
        if level >= self.log_level:
            log.log(level, message)

    def visited(self, obj) -> bool:
        """Return :any:`True` if `obj` has already been compared.

        Instances of built-in types are never recorded and always give
        :any:`False`. Any other object is recorded by :func:`id` on the first
        call, so that later calls with the same object give :any:`True`.
        """
        if type(obj).__module__ == "builtins":
            return False

        seen = VISITED[self._memo_key]
        marker = id(obj)

        already = marker in seen
        if not already:
            seen.add(marker)
        return already
| 95 | + |
| 96 | + |
# Generic type variable bounded by `object`.
# NOTE(review): not referenced in the visible portion of this module; confirm
# whether it is still needed.
T = TypeVar("T", bound=object)
| 98 | + |
| 99 | + |
@singledispatch
def compare(left: object, right, opts: Options, context: str = "") -> bool:
    """Compare `left` to `right`.

    This is the fallback used when no handler is registered for the type of
    `left`: dataclass objects are handed to :func:`compare_dataclass`, and anything
    else raises :class:`NotImplementedError`.
    """
    if not is_dataclass(left):
        raise NotImplementedError(f"Compare {type(left)} {left!r} in {context}")

    return compare_dataclass(left, right, opts, context)
| 107 | + |
| 108 | + |
def compare_dataclass(left, right, opts: Options, context: str) -> bool:
    """Compare the dataclass instance `left` to `right`, field by field.

    If `right` is :any:`None` no field is examined and the result is
    :any:`False`. Otherwise every field of `left` is read from both objects and
    the pair is compared, except that fields whose left-hand value has already
    been recorded by :meth:`Options.visited` are skipped.

    A field passes if the two values are the same object, if :func:`compare`
    returns a true value for them, or if its dotted name is listed in
    :data:`IGNORE_CONTEXT`. Each outcome is logged through :meth:`Options.log`;
    the loop stops at the first failing field unless `opts.verbose` is set.
    """
    # Prefix for field names in log messages and IGNORE_CONTEXT look-ups; falls
    # back to the class name of `left` when no context is given
    c = context or type(left).__name__

    result = right is not None
    for f in fields(left) if result else []:
        l_val, r_val = getattr(left, f.name), getattr(right, f.name)

        # NB this call also records l_val as visited when it was not yet known
        if opts.visited(l_val):
            continue  # Already compared to its counterpart

        c_sub = f"{c}.{f.name}"

        # Handle Options.allow_implied_urn: unless both URNs are None, replace a
        # missing (falsy) URN on either side with the one constructed by urn.make()
        if f.name == "urn" and not l_val is r_val is None and opts.allow_implied_urn:
            try:
                l_val = l_val or urn.make(left)
            except (AttributeError, ValueError):
                pass  # No URN can be constructed; keep the original value
            try:
                r_val = r_val or urn.make(right)
            except (AttributeError, ValueError):
                pass  # No URN can be constructed; keep the original value

        # Evaluated left to right: compare() is only called for values that are
        # not the same object, and IGNORE_CONTEXT is only consulted on a mismatch
        result_f = (
            l_val is r_val
            or compare(l_val, r_val, opts, c_sub)
            or c_sub in IGNORE_CONTEXT
        )

        result &= result_f

        if result_f is False:
            opts.log(f"Not identical: {c_sub}={shorten(l_val)} != {shorten(r_val)}")
            if not opts.verbose:
                break
        else:
            opts.log(f"{c_sub}={shorten(l_val)} == {shorten(r_val)}", logging.DEBUG)

    return result
| 148 | + |
| 149 | + |
| 150 | +# Built-in types |
| 151 | + |
| 152 | + |
| 153 | +# TODO When dropping support for Python <=3.10, change to '@compare.register' |
@compare.register(int)
@compare.register(str)
@compare.register(datetime.date)
def _eq(left: Union[int, str, datetime.date], right, opts, context=""):
    """Handle built-in types whose values must compare equal.

    When `opts.strict` is false, a `right` value of :any:`None` is also accepted.
    """
    matches = left == right
    if matches:
        return matches
    return not opts.strict and right is None
| 160 | + |
| 161 | + |
| 162 | +# TODO When dropping support for Python <=3.10, change to '@compare.register' |
@compare.register(type(None))
@compare.register(bool)
@compare.register(float)
@compare.register(type)
@compare.register(enum.Enum)
def _is(left: Union[None, bool, float, type, enum.Enum], right, opts, context=""):
    """Handle built-in types whose values must compare identical.

    Returns :any:`True` if `left` and `right` are the same object. When
    `opts.strict` is false, :any:`None` on either side is also accepted.

    `context` is unused; it defaults to an empty string, like the other handlers,
    so that callers which omit it do not fail.
    """
    if left is right:
        return True

    # The parentheses matter: `and` binds tighter than `or`, so without them a None
    # `left` would compare equal to anything even in strict mode.
    return not opts.strict and (left is None or right is None)
| 171 | + |
| 172 | + |
@compare.register
def _(left: dict, right, opts, context=""):
    """Return :obj:`True` if the mapping `left` is the same as `right`.

    Two mappings are identical if they contain the same set of keys and the
    corresponding values compare equal. A `right` object without a ``keys``
    method is treated as having no keys.

    With `opts.verbose`, values are still compared pairwise after a key mismatch,
    using :any:`None` for keys missing from `right`.
    """

    def _ordered(keys):
        """Return `keys` sorted, or in arbitrary order if they cannot be sorted."""
        try:
            return sorted(keys)
        except TypeError:
            return list(keys)

    right_is_mapping = hasattr(right, "keys")
    l_keys = set(left.keys())
    r_keys = set(right.keys()) if right_is_mapping else set()

    result = l_keys == r_keys
    if not result:
        opts.log(
            f"Mismatched {type(left).__name__} keys: {shorten(_ordered(l_keys))} "
            f"!= {shorten(_ordered(r_keys))}"
        )

    # Values cannot be paired if `right` is not a mapping; and after a key mismatch
    # they are only worth comparing in verbose mode
    if not right_is_mapping or not (result or opts.verbose):
        return result

    # Compare items pairwise
    for key in _ordered(l_keys):
        # Pass an explicit, empty context: handlers without a default for `context`
        # would otherwise fail, and compare_dataclass() substitutes the class name
        # for an empty context, which the IGNORE_CONTEXT entries rely on
        result &= compare(left[key], right.get(key, None), opts, "")
        if result is False and not opts.verbose:
            break

    return result
| 198 | + |
| 199 | + |
| 200 | +# TODO When dropping support for Python <=3.10, change to '@compare.register' |
@compare.register(list)
@compare.register(set)
def _(left: Union[list, set], right, opts, context=""):
    """Return :any:`True` if the collection `left` is the same as `right`.

    The collections must have the same length and their members must compare
    equal pairwise. Members are paired in sorted order when both collections can
    be sorted, otherwise in iteration order.

    If `right` has no length (for instance, it is :any:`None`), the result is
    :any:`False`, except that :any:`None` is accepted when `opts.strict` is false.
    With `opts.verbose`, all pairs are compared even after a mismatch.
    """
    try:
        r_len = len(right)
    except TypeError:
        # `right` is not a sized collection, so it cannot be compared pairwise
        permitted = not opts.strict and right is None
        if not permitted:
            opts.log(
                f"Mismatched types: {type(left).__name__} "
                f"!= {type(right).__name__}"
            )
        return permitted

    if len(left) != r_len:
        opts.log(f"Mismatched length: {len(left)} != {r_len}")
        return False

    try:
        l_values: Iterable = sorted(left)
        r_values: Iterable = sorted(right)
    except TypeError:
        # Members cannot be ordered; fall back to iteration order
        l_values, r_values = left, right

    result = True
    for i, (a, b) in enumerate(zip(l_values, r_values)):
        if not compare(a, b, opts, f"{context}[{i}]"):
            result = False
            # Like the other handlers, only continue after a mismatch if verbose
            if not opts.verbose:
                break

    return result
| 218 | + |
| 219 | + |
| 220 | +# Types from upstream packages |
| 221 | + |
| 222 | + |
@compare.register
def _(left: lxml.etree._Element, right, opts, context=""):
    """Compare two XML elements by their serialized form.

    Both sides are copied and passed through :func:`lxml.etree.cleanup_namespaces`
    before being serialized. If copying or cleaning `right` raises
    :class:`TypeError`, the result is :any:`True` only when `opts.strict` is false.
    """
    try:
        r_clean = copy(right)
        lxml.etree.cleanup_namespaces(r_clean)
    except TypeError:
        return not opts.strict

    l_clean = copy(left)
    lxml.etree.cleanup_namespaces(l_clean)

    l_bytes = lxml.etree.tostring(l_clean)
    r_bytes = lxml.etree.tostring(r_clean)
    return l_bytes == r_bytes
| 234 | + |
| 235 | + |
| 236 | +# SDMX types |
| 237 | + |
| 238 | + |
@compare.register
def _(left: internationalstring.InternationalString, right, opts, context=""):
    """Compare two international strings through their ``localizations``.

    If `right` is :any:`None` there is nothing to compare against; the result is
    then :any:`True` only when `opts.strict` is false.
    """
    if right is None:
        # Avoid an AttributeError on `right.localizations`
        return not opts.strict

    return compare(
        left.localizations, right.localizations, opts, f"{context}.localizations"
    )
| 244 | + |
| 245 | + |
def shorten(value: Any) -> str:
    """Return the :func:`repr` of `value`, abbreviated for log messages.

    The text is limited to 30 characters, with "…" marking any truncation.
    """
    text = repr(value)
    return textwrap.shorten(text, width=30, placeholder="…")
0 commit comments