Skip to content

Commit 26d7148

Browse files
Merge pull request #1642 from moj-analytical-services/update_sqlglot
Update sqlglot to >=13.0.0
2 parents 7489d75 + 23b17fe commit 26d7148

7 files changed

+104
-69
lines changed

poetry.lock

+6-6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-2
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ jsonschema = ">=3.2,<5.0"
1414
# 1.3.5 is the last version supporting py 3.7.1
1515
pandas = ">1.3.0"
1616
duckdb = ">=0.8.0"
17-
# normalize issue in sqlglot - temporarily exclude updates
18-
sqlglot = ">=7.0.0,<11.4.2"
17+
sqlglot = ">=13.0.0, <19.0.0"
1918
altair = "^5.0.1"
2019
Jinja2 = ">=3.0.3"
2120
phonetics = "^1.0.5"

splink/comparison_level.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sqlglot
1111
from sqlglot.expressions import Identifier
1212
from sqlglot.optimizer.normalize import normalize
13+
from sqlglot.optimizer.simplify import simplify
1314

1415
from .constants import LEVEL_NOT_OBSERVED_TEXT
1516
from .default_from_jsonschema import default_value_from_schema
@@ -495,7 +496,7 @@ def _is_exact_match(self):
495496
sql_syntax_tree = sqlglot.parse_one(
496497
self.sql_condition.lower(), read=self.sql_dialect
497498
)
498-
sql_cnf = normalize(sql_syntax_tree)
499+
sql_cnf = simplify(normalize(sql_syntax_tree))
499500

500501
exprs = _get_and_subclauses(sql_cnf)
501502
for expr in exprs:
@@ -508,7 +509,7 @@ def _exact_match_colnames(self):
508509
sql_syntax_tree = sqlglot.parse_one(
509510
self.sql_condition.lower(), read=self.sql_dialect
510511
)
511-
sql_cnf = normalize(sql_syntax_tree)
512+
sql_cnf = simplify(normalize(sql_syntax_tree))
512513

513514
exprs = _get_and_subclauses(sql_cnf)
514515
for expr in exprs:

splink/input_column.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -227,14 +227,11 @@ def _get_dialect_quotes(dialect):
227227

228228

229229
def _get_sqlglot_dialect_quotes(dialect: sqlglot.Dialect):
230-
# TODO: once we drop support for sqlglot < 6.0.0, we can simplify this
231230
try:
232-
# For sqlglot < 6.0.0
233-
quotes = dialect.identifiers
234-
quote = '"' if '"' in quotes else quotes[0]
235-
start = end = quote
231+
# For sqlglot >= 16.0.0
232+
start = dialect.IDENTIFIER_START
233+
end = dialect.IDENTIFIER_END
236234
except AttributeError:
237-
# For sqlglot >= 6.0.0
238235
start = dialect.identifier_start
239236
end = dialect.identifier_end
240237
return start, end

tests/test_comparison_level.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from pytest import mark, raises
2+
3+
from splink.comparison_level import ComparisonLevel
4+
5+
from .decorator import mark_with_dialects_excluding
6+
7+
8+
def make_comparison_level(sql_condition, dialect):
9+
return ComparisonLevel(
10+
{
11+
"sql_condition": sql_condition,
12+
"label_for_charts": "nice_informative_label",
13+
},
14+
sql_dialect=dialect,
15+
)
16+
17+
18+
# SQL conditions that are of 'exact match' type
19+
exact_matchy_sql_conditions_and_columns = [
20+
("col_l = col_r", {"col"}),
21+
("col_l = col_r AND another_col_l = another_col_r", {"col", "another_col"}),
22+
(
23+
"col_l = col_r AND another_col_l = another_col_r AND third_l = third_r",
24+
{"col", "another_col", "third"},
25+
),
26+
(
27+
"(col_l = col_r AND another_col_l = another_col_r) AND third_l = third_r",
28+
{"col", "another_col", "third"},
29+
),
30+
(
31+
"col_l = col_r AND (another_col_l = another_col_r AND third_l = third_r)",
32+
{"col", "another_col", "third"},
33+
),
34+
]
35+
36+
37+
@mark.parametrize(
38+
"sql_condition, exact_match_cols", exact_matchy_sql_conditions_and_columns
39+
)
40+
@mark_with_dialects_excluding()
41+
def test_is_exact_match_for_exact_matchy_levels(
42+
sql_condition, exact_match_cols, dialect
43+
):
44+
lev = make_comparison_level(sql_condition, dialect)
45+
assert lev._is_exact_match
46+
47+
48+
@mark.parametrize(
49+
"sql_condition, exact_match_cols", exact_matchy_sql_conditions_and_columns
50+
)
51+
@mark_with_dialects_excluding()
52+
def test_exact_match_colnames_for_exact_matchy_levels(
53+
sql_condition, exact_match_cols, dialect
54+
):
55+
lev = make_comparison_level(sql_condition, dialect)
56+
assert set(lev._exact_match_colnames) == exact_match_cols
57+
58+
59+
# SQL conditions that are NOT of 'exact match' type
60+
non_exact_matchy_sql_conditions = [
61+
"levenshtein(col_l, col_r) < 3",
62+
"col_l < col_r",
63+
"col_l = col_r OR another_col_l = another_col_r",
64+
"col_l = a_different_col_r",
65+
"col_l = col_r AND (col_2_l = col_2_r OR col_3_l = col_3_r)",
66+
"col_l = col_r AND (col_2_l < col_2_r)",
67+
"substr(col_l, 2) = substr(col_r, 2)",
68+
]
69+
70+
71+
@mark.parametrize("sql_condition", non_exact_matchy_sql_conditions)
72+
@mark_with_dialects_excluding()
73+
def test_is_exact_match_for_non_exact_matchy_levels(sql_condition, dialect):
74+
lev = make_comparison_level(sql_condition, dialect)
75+
assert not lev._is_exact_match
76+
77+
78+
@mark.parametrize("sql_condition", non_exact_matchy_sql_conditions)
79+
@mark_with_dialects_excluding()
80+
def test_exact_match_colnames_for_non_exact_matchy_levels(sql_condition, dialect):
81+
lev = make_comparison_level(sql_condition, dialect)
82+
# _exact_match_colnames should raise a ValueError if it is
83+
# not actually an exact match level
84+
with raises(ValueError):
85+
lev._exact_match_colnames

tests/test_compound_comparison_levels.py

-51
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
import pandas as pd
2-
from sqlglot import parse_one
3-
from sqlglot.optimizer.normalize import normalize
42

53
import splink.duckdb.comparison_level_library as cll
64
import splink.duckdb.comparison_library as cl
@@ -216,52 +214,3 @@ def test_complex_compound_comparison_level():
216214
linker = DuckDBLinker(df, settings)
217215

218216
linker.estimate_parameters_using_expectation_maximisation("1=1")
219-
220-
221-
def test_normalise():
222-
# check that the sqlglot normaliser is doing what we think
223-
# try to not impose specific form too strongly, so we aren't too tightly
224-
# coupled to the implementation
225-
sql_syntax_tree = parse_one("a or (b and c)")
226-
sql_cnf = normalize(sql_syntax_tree).sql().lower()
227-
228-
subclauses_expected = [
229-
["a or c", "c or a"],
230-
["a or b", "b or a"],
231-
]
232-
233-
# get subclauses and remove outer parens
234-
subclauses_found = map(lambda s: s.strip("()"), sql_cnf.split(" and "))
235-
236-
# loop through subclauses, make sure that we have exactly one of each
237-
for found in subclauses_found:
238-
term_found = False
239-
for i, expected in enumerate(subclauses_expected):
240-
if found in expected:
241-
del subclauses_expected[i]
242-
term_found = True
243-
break
244-
assert term_found, f"CNF contains unexpected clause '{found}'"
245-
assert not subclauses_expected
246-
247-
# and a slightly more complex statement
248-
sql_syntax_tree = parse_one("(a and b) or (a and c) or (c and d) or (d and b)")
249-
sql_cnf = normalize(sql_syntax_tree).sql().lower()
250-
251-
subclauses_expected = [
252-
["b or c", "c or b"],
253-
["a or d", "d or a"],
254-
]
255-
256-
subclauses_found = map(lambda s: s.strip("()"), sql_cnf.split(" and "))
257-
258-
# loop through subclauses, make sure that we have exactly one of each
259-
for found in subclauses_found:
260-
term_found = False
261-
for i, expected in enumerate(subclauses_expected):
262-
if found in expected:
263-
del subclauses_expected[i]
264-
term_found = True
265-
break
266-
assert term_found, f"CNF contains unexpected clause '{found}'"
267-
assert not subclauses_expected

tests/test_sql_transform.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,16 @@ def test_move_l_r_table_prefix_to_column_suffix():
3636
move_l_r_test(br, expected)
3737

3838
br = "len(list_filter(l.name_list, x -> list_contains(r.name_list, x))) >= 1"
39-
expected = "len(list_filter(name_list_l, x -> list_contains(name_list_r, x))) >= 1"
39+
expected = (
40+
"length(list_filter(name_list_l, x -> list_contains(name_list_r, x))) >= 1"
41+
)
4042
move_l_r_test(br, expected)
4143

4244
br = "len(list_filter(l.name_list, x -> list_contains(r.name_list, x))) >= 1"
4345
res = move_l_r_table_prefix_to_column_suffix(br)
44-
expected = "len(list_filter(name_list_l, x -> list_contains(name_list_r, x))) >= 1"
46+
expected = (
47+
"length(list_filter(name_list_l, x -> list_contains(name_list_r, x))) >= 1"
48+
)
4549
assert res.lower() == expected.lower()
4650

4751

0 commit comments

Comments
 (0)