(WIP) 2580 improve runtimes by pushing up common case statements into precomputed values #2630


Conversation

RobinL commented Feb 19, 2025

Very much work in progress for now, but the overall approach here seems to work

See #2580

Closes #2580

RobinL commented Feb 19, 2025

import splink.comparison_library as cl
import pandas as pd
from splink import DuckDBAPI, Linker, SettingsCreator, block_on, splink_datasets


db_api = DuckDBAPI()

df = splink_datasets.fake_1000
df = pd.concat([df]*5)
df.reset_index(drop=True, inplace=True)
df["unique_id"] = df.index

thresholds = [n / 100 for n in range(100, 20, -10)]
lev_thresholds = list(range(1, 10))

settings = SettingsCreator(
    link_type="dedupe_only",
    comparisons=[
        cl.JaroWinklerAtThresholds("first_name", thresholds),
        cl.JaccardAtThresholds("surname", thresholds),
        cl.LevenshteinAtThresholds("email", lev_thresholds),
    ],
    blocking_rules_to_generate_predictions=[
        "1=1",
    ],
    max_iterations=2,
)

linker = Linker(df, settings, db_api)

pairwise_predictions = linker.inference.predict(
    threshold_match_weight=-10, experimental_optimisation=True
)

# Sum of match probabilities, used to check the optimised and unoptimised runs agree
sql = f"""
SELECT sum(match_probability)
FROM {pairwise_predictions.physical_name}
"""

display(linker.misc.query_sql(sql))


db_api2 = DuckDBAPI()
linker2 = Linker(df, settings, db_api2)

pairwise_predictions2 = linker2.inference.predict(
    threshold_match_weight=-10, experimental_optimisation=False
)
sql = f"""
SELECT sum(match_probability)
FROM {pairwise_predictions2.physical_name}
"""

display(linker2.misc.query_sql(sql))

Existing code:
Blocking time: 0.26 seconds
Predict time: 4.45 seconds

With speedup:
Blocking time: 0.26 seconds
Predict time: 0.68 seconds

RobinL commented Feb 19, 2025

I think the structure of this is a bit wrong: passing the experimental_optimisation flag isn't quite right; I think this should 'belong' to the comparison.

And we want to perform the optimisation within the comparison, so that we already have access to the SQL dialect.

But for now I'm just going to get it working.

Also, remember this current method doesn't immediately work with compare_two_records()

RobinL commented Feb 19, 2025

Remember to add some tests based on the context.md file

RobinL commented Feb 23, 2025

The following fails on sqlglot 25.x but works on 26.6.0.

The error is something to do with XOR at q^0.33, i.e. it's not correctly translating/transpiling the ^ (DuckDB's power operator).

(
    list_reduce(
        list_prepend(
            1.0,
            list_transform(
                token_rel_freq_arr_l,
                x -> CASE
                        WHEN array_contains(
                            list_transform(token_rel_freq_arr_r, y -> y.tok),
                            x.tok
                        )
                        THEN x.rel_freq
                        ELSE 1.0
                    END
            )
        ),
        (p, q) -> p * q
    ) * 
    list_reduce(
        list_prepend(
            1.0,
            list_transform(
                list_concat(
                    array_filter(
                        token_rel_freq_arr_l,
                        y -> NOT array_contains(
                                list_transform(token_rel_freq_arr_r, x -> x.tok),
                                y.tok
                            )
                    ),
                    array_filter(
                        token_rel_freq_arr_r,
                        y -> NOT array_contains(
                                list_transform(token_rel_freq_arr_l, x -> x.tok),
                                y.tok
                            )
                    )
                ),
                x -> x.rel_freq
            )
        ),
        (p, q) -> p / q^0.33
    ))
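
A minimal repro for checking how a given sqlglot version handles DuckDB's ^ (its power operator) might be the round-trip below; the exact output is version-dependent, which seems to be the cause of the failure above.

```python
import sqlglot

print(sqlglot.__version__)
# Round-trip a tiny expression through the DuckDB dialect. On versions that
# parse ^ as bitwise XOR rather than POW, the regenerated SQL changes meaning.
print(sqlglot.transpile("SELECT p / q ^ 0.33 AS x", read="duckdb", write="duckdb")[0])
```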

RobinL commented Jul 4, 2025

Quickprompt is

Details

Consider the following example:

select
CASE
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 4
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.8 THEN 3
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.7 THEN 2
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.6 THEN 1
else 0
end as gamma_first_name
from df

The following query does the same thing, but is about 4x faster

with reusable as (
    select jaro_winkler_similarity("first_name_l", "first_name_r") as jws
    from df
)
select
CASE
WHEN jws >= 0.9 THEN 4
WHEN jws >= 0.8 THEN 3
WHEN jws >= 0.7 THEN 2
WHEN jws >= 0.6 THEN 1
else 0
end as gamma_first_name
from reusable
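
A rough way to sanity-check this kind of speedup outside Splink (a sketch only; the toy data is made up and real timings should come from Splink's blocked pairs):

```python
import time

import duckdb
import pandas as pd

# Toy pairs table standing in for blocked_with_cols
df = pd.DataFrame(
    {
        "first_name_l": ["robin", "robyn", "robert"] * 200_000,
        "first_name_r": ["robin", "roby", "rob"] * 200_000,
    }
)
con = duckdb.connect()
con.register("df", df)

naive = """
select CASE
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 4
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.8 THEN 3
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.7 THEN 2
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.6 THEN 1
else 0 end as gamma_first_name
from df
"""

precomputed = """
with reusable as (
    select jaro_winkler_similarity("first_name_l", "first_name_r") as jws from df
)
select CASE
WHEN jws >= 0.9 THEN 4 WHEN jws >= 0.8 THEN 3
WHEN jws >= 0.7 THEN 2 WHEN jws >= 0.6 THEN 1
else 0 end as gamma_first_name
from reusable
"""

for label, sql in [("naive", naive), ("precomputed", precomputed)]:
    t0 = time.time()
    con.execute(sql).fetchall()
    print(f"{label}: {time.time() - t0:.2f} seconds")
```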


I have a variety of sql statements that I want to optimise.  Here are some examples:

Example 1:

select
CASE
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 4
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.8 THEN 3
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.7 THEN 2
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.6 THEN 1
else 0
end as gamma_first_name

the reusable part here is jaro_winkler_similarity("first_name_l", "first_name_r")

Example 2:

CASE
WHEN "first_name_l" IS NULL OR "first_name_r" IS NULL THEN -1
WHEN array_length(list_intersect("first_name_l", "first_name_r")) > 4 THEN 4
WHEN array_length(list_intersect("first_name_l", "first_name_r")) > 3 THEN 3
WHEN array_length(list_intersect("first_name_l", "first_name_r")) > 2 THEN 2
WHEN array_length(list_intersect("first_name_l", "first_name_r")) > 1 THEN 1 ELSE 0 END as gamma_first_name

the reusable part here is array_length(list_intersect("first_name_l", "first_name_r"))

Example 3:

CASE
WHEN try_strptime("dob_l", '%Y-%m-%d') IS NULL OR try_strptime("dob_r", '%Y-%m-%d') IS NULL THEN -1
WHEN "dob_l" = "dob_r" THEN 4
WHEN ABS(EPOCH(try_strptime("dob_l", '%Y-%m-%d')) - EPOCH(try_strptime("dob_r", '%Y-%m-%d'))) <= 86400 THEN 3
WHEN ABS(EPOCH(try_strptime("dob_l", '%Y-%m-%d')) - EPOCH(try_strptime("dob_r", '%Y-%m-%d'))) <= 2629800.0 THEN 2
WHEN ABS(EPOCH(try_strptime("dob_l", '%Y-%m-%d')) - EPOCH(try_strptime("dob_r", '%Y-%m-%d'))) <= 31557600.0 THEN 1 ELSE 0 END as gamma_dob

The reusable part here is
ABS(EPOCH(try_strptime("dob_l", '%Y-%m-%d')) - EPOCH(try_strptime("dob_r", '%Y-%m-%d')))

Example 4:

CASE
WHEN "name_tokens_with_freq_l" IS NULL OR "name_tokens_with_freq_r" IS NULL THEN -1
WHEN list_reduce((p, q) -> p * q, list_concat(1.0::FLOAT, list_transform(list_intersect(email_l, email_r), x -> x.rel_freq::float))) < 1e-12 THEN 5
WHEN list_reduce((p, q) -> p * q, list_concat(1.0::FLOAT, list_transform(list_intersect(email_l, email_r), x -> x.rel_freq::float))) < 1e-10 THEN 4
WHEN list_reduce((p, q) -> p * q, list_concat(1.0::FLOAT, list_transform(list_intersect(email_l, email_r), x -> x.rel_freq::float))) < 1e-8 THEN 3
WHEN list_reduce((p, q) -> p * q, list_concat(1.0::FLOAT, list_transform(list_intersect(email_l, email_r), x -> x.rel_freq::float))) < 1e-6 THEN 2
WHEN list_reduce((p, q) -> p * q, list_concat(1.0::FLOAT, list_transform(list_intersect(email_l, email_r), x -> x.rel_freq::float))) < 1e-4 THEN 1 ELSE 0 END as gamma_name_tokens_with_freq

the reusable part here is list_reduce((p, q) -> p * q, list_concat(1.0::FLOAT, list_transform(list_intersect(email_l, email_r),x -> x.rel_freq::float)))



I'm working in Python. I want an elegant, high-quality approach to finding and extracting the reusable part, and replacing it with a variable
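
A minimal sketch of one possible approach with sqlglot (not the final implementation; nested duplicates and NULL-check conditions are ignored here): parse the CASE expression, count structurally identical function calls via their generated SQL, and replace any that appear more than once with a generated column name.

```python
from collections import Counter

import sqlglot
from sqlglot import exp


def factor_out_repeated_functions(case_sql: str, dialect: str = "duckdb"):
    """Return (mapping of repeated function SQL -> alias, rewritten CASE sql)."""
    ast = sqlglot.parse_one(case_sql, read=dialect)

    # Count each function call by its generated SQL, which acts as a structural key
    counts = Counter(f.sql(dialect=dialect) for f in ast.find_all(exp.Func))
    aliases = {
        sql: f"rf_{i}"
        for i, sql in enumerate(sorted(s for s, c in counts.items() if c > 1), 1)
    }

    def replace(node):
        if isinstance(node, exp.Func) and node.sql(dialect=dialect) in aliases:
            return exp.to_identifier(aliases[node.sql(dialect=dialect)])
        return node

    return aliases, ast.transform(replace).sql(dialect=dialect)


example = """CASE
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 4
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.6 THEN 1
ELSE 0 END"""

aliases, rewritten = factor_out_repeated_functions(example)
print(aliases)    # roughly: {'jaro_winkler_similarity("first_name_l", "first_name_r")': 'rf_1'}
print(rewritten)  # roughly: CASE WHEN rf_1 >= 0.9 THEN 4 WHEN rf_1 >= 0.6 THEN 1 ELSE 0 END
```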

Detailed prompt is

Details
# Context for Optimizing Repeated Function Calls in CASE Statements

## Overview

Splink's performance can suffer because within the `CASE` statements (used to compute comparison vector values) the same function (e.g. `jaro_winkler_similarity`, `array_length(list_intersect(...))`, etc.) is often computed multiple times for different thresholds. The goal of this project is to refactor such SQL so that any repeated function call is computed only once and stored as a reusable column. For example, instead of:

SELECT
    CASE
    WHEN "first_name_l" IS NULL OR "first_name_r" IS NULL THEN -1
    WHEN "first_name_l" = "first_name_r" THEN 3
    WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 2
    WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.7 THEN 1
    ELSE 0
    END as gamma_first_name
FROM blocked_with_cols

we want to generate something like:

WITH __splink__reusable_function_values AS (
    SELECT
        *,
        jaro_winkler_similarity("first_name_l", "first_name_r") AS jws_first_name
    FROM blocked_with_cols
)
SELECT
    CASE
    WHEN "first_name_l" IS NULL OR "first_name_r" IS NULL THEN -1
    WHEN "first_name_l" = "first_name_r" THEN 3
    WHEN jws_first_name >= 0.9 THEN 2
    WHEN jws_first_name >= 0.7 THEN 1
    ELSE 0
    END as gamma_first_name
FROM __splink__reusable_function_values

To implement, a possible approach is to use SQLGlot to:

  • Parse the SQL AST.
  • Identify any repeated function expressions (by counting occurrences).
  • Replace repeated calls with a variable/alias.
  • Build a CTE that computes all such reusable expressions once.
  • Reassemble the query so that downstream processes use the precomputed values.

How This Fits into Splink

The CASE statement in which these repeated function calls occur is produced by the Comparison class. The optimized approach is intended to be generic so that it automatically finds any repeated function calls within the SQL (and later, similar expressions in comparison levels).

Main Files Affected

  • @comparison.py
    Contains the Comparison class. In particular, the _case_statement property (around line 220) builds the CASE statement for each comparison. The output from this method is used downstream to compute the comparison vector values and is called in:

    • _columns_to_select_for_comparison_vector_values
    • Other methods in building the final SELECT clause.
  • @comparison_level.py
    Defines the ComparisonLevel class. Each level produces a piece of the CASE statement (via _when_then_comparison_vector_value_sql). Optimizations must ensure that any repeated function call from these SQL fragments is factored out.

  • @predict.py
    In functions such as predict_from_comparison_vectors_sqls_using_settings, the CASE statement (along with other SQL fragments) is combined into one large SQL query that is sent to the DB. The optimized SQL (with the extra CTE for reused columns) needs to be correctly integrated into this process.

  • @inference.py
    The LinkerInference class (and its predict method) orchestrates the prediction pipeline. This call stack eventually uses the SQL generated from the Comparison objects and must handle the modified SQL query when the optimization is applied.

  • @linker.py
    The Linker class ties together the entire workflow (training, inference, clustering, etc.). This is where the new optimized SQL generation will finally be invoked as part of the overall pipeline.

Call Stack Example

In a typical prediction run the following events occur:

  1. Comparison._case_statement (in comparison.py) builds a CASE clause using SQL fragments provided by each ComparisonLevel.
  2. The method _columns_to_select_for_comparison_vector_values (in both comparison.py and settings.py) collects these CASE statements amongst other columns.
  3. In predict.py, these columns are combined into a SELECT statement which is executed against the database.
  4. Finally, LinkerInference.predict (in inference.py) and Linker.predict (in linker.py) trigger the entire pipeline.

Your optimization code (using SQLGlot) will modify the SQL at step (1) so that any repeated function call is replaced by a reusable variable, and an appropriate CTE is added to precompute these values.

Implementation Highlights

  • SQLGlot-based Transformations:
    A two-pass approach is applied:

    • First pass: Count all function calls and mark those that are repeated. Replace the repeated instances (ensuring that nested repeated calls are not double-replaced) with a placeholder literal (based on a cleaned version of the function's SQL representation).
    • Second pass: Convert these literal placeholders into identifier nodes referencing the computed alias.
  • Building the CTE:
    A new CTE is constructed that selects the original table (aliased as t) along with new computed columns corresponding to the extracted repeated function calls. This CTE is then injected into the main query (replacing the original FROM clause with one referencing the CTE).

  • Testing and Debugging:
    Several test cases (e.g., for jaro_winkler_similarity, array_length(list_intersect(...)), ABS(EPOCH(try_strptime(...))), and list_reduce(...)) have been defined in the standalone script to ensure the transformation works. In case of failure, detailed SQL or AST tree logs will help with debugging.

Next Steps for Integration

  • Modify/extend Comparison._case_statement (in @comparison.py) so that if a flag (e.g. has_optimised_sql_condition) is enabled, the SQL produced goes through the optimization routine.
  • Ensure that the transformation is transparent to the rest of the code. The final SQL must maintain all correct aliases and output column names.
  • Validate the end-to-end pipeline by running tests within the Splink framework using the updated prediction methods in @predict.py and @inference.py.
  • Verify that the additional CTE does not interfere with downstream components in @linker.py.

Questions & Verification

Before fully integrating, consider running these checks:

  • Does the optimized SQL produce the same result as the original while improving performance?
  • Are there any dialect-specific issues (the dialect used is DuckDB)?
  • How do the new temporary variable names (generated from cleaned function SQL strings) interact with user-defined aliases?
  • What additional debug output would be most helpful to diagnose any discrepancies between the optimized and original SQL ASTs?

SQL Pipeline Details

The optimization needs to be integrated into a specific part of Splink's SQL generation pipeline. Here's the detailed flow:

  1. The comparison vector values are computed in multiple stages through SQL CTEs:
-- Stage 1: Join blocked pairs with input data
WITH blocked_with_cols AS (
    SELECT
        l.unique_id AS unique_id_l,
        r.unique_id AS unique_id_r,
        l.first_name AS first_name_l,
        r.first_name AS first_name_r,
        -- ... other columns
    FROM __splink__df_concat_with_tf AS l
    INNER JOIN __splink__blocked_id_pairs AS b
        ON l.unique_id = b.join_key_l
    INNER JOIN __splink__df_concat_with_tf AS r
        ON r.unique_id = b.join_key_r
),

-- Stage 2 (NEW): Compute reusable function values
__splink__reusable_function_values AS (
    SELECT
        *,
        jaro_winkler_similarity(first_name_l, first_name_r) AS jws_first_name,
        jaccard(surname_l, surname_r) AS jd_surname
        -- ... other reusable computations
    FROM blocked_with_cols
),

-- Stage 3: Compute comparison vectors using reusable values
__splink__df_comparison_vectors AS (
    SELECT
        unique_id_l,
        unique_id_r,
        first_name_l,
        first_name_r,
        CASE
            WHEN first_name_l IS NULL OR first_name_r IS NULL THEN -1
            WHEN first_name_l = first_name_r THEN 3
            WHEN jws_first_name >= 0.9 THEN 2
            WHEN jws_first_name >= 0.7 THEN 1
            ELSE 0
        END as gamma_first_name,
        -- ... other comparison vectors
    FROM __splink__reusable_function_values
)

Key Integration Points

  1. Comparison._case_statement

    • Currently returns a single CASE statement as a string
    • Needs to be modified to optionally return both:
      • The reusable function computations for the CTE
      • The modified CASE statement using the computed values
    • A new property like _reusable_expressions could store the extracted function calls
  2. compute_comparison_vector_values_from_id_pairs_sqls

    • Currently returns a list of two SQL statements:
      1. Join blocked pairs with input data (blocked_with_cols)
      2. Compute comparison vectors (__splink__df_comparison_vectors)
    • Should be modified to insert a new intermediate CTE for reusable computations
    • The list of SQLs will become (a chaining sketch is shown after this list):
      [
          {
              'sql': '... join blocked pairs ...',
              'output_table_name': 'blocked_with_cols'
          },
          {
              'sql': '... compute reusable values ...',
              'output_table_name': '__splink__reusable_function_values'
          },
          {
              'sql': '... compute comparison vectors using reusable values ...',
              'output_table_name': '__splink__df_comparison_vectors'
          }
      ]
  3. Settings Integration

    • The list _columns_to_select_for_comparison_vector_values contains both regular columns and CASE statements
    • The optimization needs to:
      1. Extract reusable functions from these CASE statements
      2. Add new columns for the reusable computations
      3. Modify the CASE statements to use the computed values
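
For illustration, a simplified sketch of how a list of {'sql', 'output_table_name'} steps like the one above can be chained into a single query; Splink's own pipeline machinery does the real work, so this only shows the shape of the change.

```python
def chain_pipeline_steps(steps: list[dict]) -> str:
    """Chain all but the last step into CTEs; the last step becomes the final SELECT."""
    *cte_steps, final_step = steps
    ctes = ",\n\n".join(
        f"{s['output_table_name']} as (\n{s['sql']}\n)" for s in cte_steps
    )
    return f"WITH\n\n{ctes}\n\n{final_step['sql']}"


steps = [
    {
        "sql": "select * from __splink__blocked_id_pairs",  # stand-in for the real join
        "output_table_name": "blocked_with_cols",
    },
    {
        "sql": (
            "select *, jaro_winkler_similarity(first_name_l, first_name_r) as rf_1\n"
            "from blocked_with_cols"
        ),
        "output_table_name": "__splink__reusable_function_values",
    },
    {
        "sql": (
            "select case when rf_1 >= 0.9 then 2 when rf_1 >= 0.7 then 1 else 0 end as gamma_first_name\n"
            "from __splink__reusable_function_values"
        ),
        "output_table_name": "__splink__df_comparison_vectors",
    },
]

print(chain_pipeline_steps(steps))
```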

Implementation Strategy

  1. Create a new class/module to handle the SQL optimization:

    class SQLOptimizer:
        def __init__(self, sql, dialect="duckdb"):
            self.sql = sql
            self.dialect = dialect

        def find_reusable_expressions(self):
            # Use SQLGlot to find repeated function calls
            ...

        def generate_cte_sql(self):
            # Generate SQL for computing reusable values
            ...

        def generate_modified_case_sql(self):
            # Generate modified CASE using computed values
            ...
  2. Modify Comparison class:

    class Comparison:
        def __init__(self):
            self.has_optimised_sql_condition = False
            self._reusable_expressions = {}
    
        @property
        def _case_statement(self):
            if not self.has_optimised_sql_condition:
                return self._original_case_statement()
    
            return self._optimised_case_statement()
  3. Add configuration options to enable/disable optimization:

    • Global flag in settings
    • Per-comparison flag
    • Minimum repetition threshold for optimization

Testing Strategy

  1. Unit tests for SQL transformation:

    • Test each example case (jaro_winkler, array_length, etc.)
    • Verify SQL equivalence
    • Check performance improvement
  2. Integration tests:

    • End-to-end comparison of results
    • Verify all comparison types still work
    • Test with different dialects
  3. Performance benchmarks:

    • Measure impact on query execution time
    • Monitor memory usage
    • Test with different data sizes

Detailed SQL Pipeline Structure

The complete SQL pipeline for prediction consists of several CTEs:

  1. __splink__df_concat_with_tf:

    • Initial table creation with input data
    • Includes random salt for sampling
    • Contains raw input columns
  2. __splink__blocked_id_pairs:

    • Contains three columns: match_key, join_key_l, join_key_r
    • Created from blocking rules
    • Filters using l.unique_id < r.unique_id for dedupe
  3. blocked_with_cols:

    • Joins blocked pairs back to input data
    • Creates _l and _r versions of each column
    • Includes match_key from blocking
  4. __splink__df_comparison_vectors:

    • Contains the CASE statements we want to optimize
    • Each CASE produces a gamma_* column
    • This is where our optimization will focus
  5. __splink__df_match_weight_parts:

    • Converts gamma values to Bayes factors
    • Creates bf_* columns
  6. Final SELECT:

    • Computes match_weight and match_probability
    • Applies threshold filtering
    • Returns final result columns

CASE Statement AST Structure

The CASE statements we need to optimize have a specific structure in SQLGlot:

Alias(
  this=Case(
    ifs=[
      If(this=<condition>, true=<value>),
      If(this=<condition>, true=<value>),
      ...
    ],
    default=<else_value>
  ),
  alias=Identifier(this="gamma_*")
)

Key points about the AST:

  • Function calls appear in the this field of If nodes
  • The same function call can appear in multiple If nodes
  • The alias must be preserved as gamma_* in the output
  • Some conditions (like IS NULL) should not be considered for optimization (see the sketch below)
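
A short sketch of walking that structure with sqlglot, skipping NULL-check levels and collecting candidate function calls (DuckDB dialect assumed):

```python
import sqlglot
from sqlglot import exp

sql = """CASE
WHEN "first_name_l" IS NULL OR "first_name_r" IS NULL THEN -1
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 2
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.7 THEN 1
ELSE 0 END as gamma_first_name"""

alias_node = sqlglot.parse_one(sql, read="duckdb")
case_node = alias_node.find(exp.Case)

for if_node in case_node.args["ifs"]:
    condition = if_node.this
    # Skip NULL-check levels: nothing worth precomputing there
    if condition.find(exp.Is):
        continue
    for func in condition.find_all(exp.Func):
        print(func.sql(dialect="duckdb"))
```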

Implementation Details

  1. The optimization should be inserted between the blocked_with_cols and __splink__df_comparison_vectors CTEs.

  2. Column naming conventions:

    • Input columns use _l and _r suffixes
    • Gamma columns must be named gamma_*
    • Reusable function columns should use a consistent prefix/pattern
  3. Performance considerations:

    • The optimization runs once per CASE statement
    • Each CASE statement belongs to a Comparison object
    • Multiple Comparisons can be present in a single query

SQLGlot Implementation Patterns

Based on the working solution in o3_mini_high.py, here's how to effectively use SQLGlot for SQL transformation:

  1. Two-Pass Transformation Pattern

    def optimize_sql(sql: str) -> str:
        ast = sqlglot.parse_one(sql, read="duckdb")
        var_mapping: dict = {}  # populated by replace_repeated_functions

        # First pass: Find repeated functions and replace with placeholders
        repeated_funcs = count_repeated_functions(ast)
        transformed_ast = replace_repeated_functions(ast, repeated_funcs, var_mapping)

        # Second pass: Convert placeholders to column references
        main_ast = restore_literals_to_identifiers(transformed_ast, var_mapping)
        return main_ast.sql(dialect="duckdb")
  2. Finding Repeated Functions

    def count_repeated_functions(ast: exp.Expression, dialect: str = "duckdb") -> Set[str]:
        function_counts: Dict[str, int] = {}
        for func in ast.find_all(exp.Func):
            func_sql = func.sql(dialect=dialect)
            function_counts[func_sql] = function_counts.get(func_sql, 0) + 1
        return {f for f, count in function_counts.items() if count > 1}
  3. Variable Name Generation

    def get_variable_name(func_sql: str) -> str:
        """Generate a sanitized variable name from the function SQL string."""
        return "".join(c if c.isalnum() else "_" for c in func_sql.lower())
  4. Handling Nested Functions

    # Skip if nested in another repeated function (prefer outermost)
    if (node.parent and
        isinstance(node.parent, exp.Func) and
        node.parent.sql(dialect=dialect) in repeated_funcs):
        return node
  5. Building CTEs

    def build_reusable_cte(table: exp.Table, var_mapping: Dict[str, exp.Expression]) -> exp.CTE:
        computed_columns = []
        for var, func_expr in var_mapping.items():
            qualified_expr = qualify_columns(func_expr, table_alias)
            computed_columns.append(
                exp.Alias(this=qualified_expr, alias=exp.Identifier(this=var))
            )
    
        # Build CTE with star plus computed columns
        cte_select = exp.Select(
            expressions=[exp.Star()] + computed_columns,
            from_=table_alias
        )
        return exp.CTE(
            this=cte_select,
            alias=exp.TableAlias(this=exp.Identifier(this="reusable"))
        )

Key SQLGlot Patterns

  1. AST Navigation

    • Use find_all() to locate specific node types
    • Check node.parent for context
    • Use transform() for tree modification
  2. SQL Generation

    • Use sql(dialect="duckdb") to generate SQL strings
    • Always specify dialect to ensure consistent output
    • Use pretty=True for readable output
  3. Node Construction

    • Create nodes using exp.* constructors
    • Build complex expressions bottom-up
    • Use copy() to duplicate nodes safely
  4. AST Manipulation Best Practices

    # Transform pattern
    def transform_func(node):
        if isinstance(node, exp.Func):
            # Modify node
            return new_node
        return node
    
    new_ast = ast.transform(transform_func)

Applying to Splink

  1. Integration Points

    class Comparison:
        def _optimised_case_statement(self):
            # Get original SQL
            sql = self._original_case_statement()
    
            # Parse and optimize
            ast = sqlglot.parse_one(sql, read="duckdb")
            optimizer = SQLOptimizer(ast)
    
            # Return both parts
            return {
                "reusable_expressions": optimizer.get_reusable_sql(),
                "case_statement": optimizer.get_modified_case_sql()
            }
  2. Error Handling

    try:
        ast = sqlglot.parse_one(sql, read="duckdb")
    except sqlglot.ParseError as e:
        logger.error(f"Failed to parse SQL: {sql}")
        raise
  3. Testing Patterns

    def compare_query_results(original_sql: str, optimized_sql: str):
        """Execute both queries and verify results match"""
        orig_res = execute_sql(original_sql)
        opt_res = execute_sql(optimized_sql)
        assert orig_res == opt_res, "Results differ!"

Common Pitfalls

  1. AST Modification

    • Always create new nodes rather than modifying existing ones
    • Use copy() when storing nodes for later use (see the example after this list)
    • Be careful with parent references when moving nodes
  2. SQL Generation

    • Always specify dialect when generating SQL
    • Verify column names and aliases are preserved
    • Handle NULL values and edge cases
  3. Performance

    • Cache repeated SQL generation calls
    • Minimize tree traversals
    • Use sets for lookups
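
To illustrate the first pitfall, a tiny example of copying a node before reusing it in a second tree, so the CTE column and the rewritten CASE do not share mutable state:

```python
from sqlglot import exp, parse_one

case = parse_one(
    "CASE WHEN jaro_winkler_similarity(a, b) >= 0.9 THEN 1 ELSE 0 END",
    read="duckdb",
)
func = case.find(exp.Func)

# Copy before placing the node in a new tree; otherwise the original AST and
# the CTE column would share children and parent pointers
cte_column = exp.alias_(func.copy(), "jws_first_name")
print(cte_column.sql(dialect="duckdb"))  # roughly: jaro_winkler_similarity(a, b) AS jws_first_name
```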

This document provides all the details needed to incorporate the optimization into Splink. Please ask any further questions or run additional tests as needed to help iterate or improve this integration.

Additional Context and Clarifications

Actual SQL Structure

The CASE statements we need to optimize appear in the __splink__df_comparison_vectors CTE. For example:

CASE
WHEN "first_name_l" IS NULL OR "first_name_r" IS NULL THEN -1
WHEN "first_name_l" = "first_name_r" THEN 3
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 2
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.7 THEN 1
ELSE 0
END as gamma_first_name

Integration Points

The CASE statement is generated in Comparison._case_statement. This property:

  1. Is called by _columns_to_select_for_comparison_vector_values
  2. Returns a string containing the CASE statement
  3. Currently has no optimization logic

SQLGlot AST Structure

The parsed AST for these CASE statements has this structure:

Alias(
  this=Case(
    ifs=[
      If(this=<condition>, true=<value>),
      If(this=<condition>, true=<value>),
      ...
    ],
    default=Literal(this=0, is_string=False)
  ),
  alias=Identifier(this="gamma_*")
)

CTE Pipeline Structure

The full SQL query has this CTE structure:

  1. blocked_with_cols: Joins blocked pairs with input data
  2. __splink__df_comparison_vectors: Contains the CASE statements we need to optimize
  3. __splink__df_match_weight_parts: Converts gamma values to Bayes factors
  4. Final SELECT: Computes match weights and probabilities

Our optimization should be inserted between blocked_with_cols and __splink__df_comparison_vectors.

Key Requirements

  1. Must preserve the gamma_* column names and values (an equivalence check is sketched after this list)
  2. Must handle NULL checks appropriately (they should remain as-is)
  3. Must work with DuckDB dialect
  4. Should not modify the structure of other CTEs
  5. Must handle multiple comparisons in the same query (e.g. first_name AND surname)
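
For requirements 1 and 2, a minimal equivalence check in DuckDB could look like the sketch below (toy rows made up; a real test should use Splink's blocked pairs):

```python
import duckdb

con = duckdb.connect()
con.execute("""
    CREATE TABLE blocked_with_cols AS
    SELECT * FROM (VALUES
        ('robin', 'robyn'),
        ('anna', NULL),
        ('bob', 'bob')
    ) AS t(first_name_l, first_name_r)
""")

original = """
SELECT CASE
  WHEN first_name_l IS NULL OR first_name_r IS NULL THEN -1
  WHEN jaro_winkler_similarity(first_name_l, first_name_r) >= 0.9 THEN 2
  WHEN jaro_winkler_similarity(first_name_l, first_name_r) >= 0.7 THEN 1
  ELSE 0 END AS gamma_first_name
FROM blocked_with_cols
"""

optimised = """
WITH __splink__reusable_function_values AS (
  SELECT *, jaro_winkler_similarity(first_name_l, first_name_r) AS jws_first_name
  FROM blocked_with_cols
)
SELECT CASE
  WHEN first_name_l IS NULL OR first_name_r IS NULL THEN -1
  WHEN jws_first_name >= 0.9 THEN 2
  WHEN jws_first_name >= 0.7 THEN 1
  ELSE 0 END AS gamma_first_name
FROM __splink__reusable_function_values
"""

# The gamma values must be identical whichever form is used
assert sorted(con.execute(original).fetchall()) == sorted(con.execute(optimised).fetchall())
print("gamma values identical")
```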

ComparisonLevel Details

Each CASE statement is built from multiple ComparisonLevel objects, which:

  1. Define individual WHEN/THEN conditions
  2. Can include NULL checks, exact matches, and function-based comparisons
  3. Are combined in the Comparison._case_statement property

Testing Considerations

The sandbox code provides a good test case with:

  • Two comparisons (first_name and surname)
  • Different functions (jaro_winkler_similarity and jaccard)
  • Multiple thresholds per comparison
  • NULL handling

The sandbox is:

import splink.comparison_library as cl
from splink import DuckDBAPI, Linker, SettingsCreator, block_on, splink_datasets

db_api = DuckDBAPI()

df = splink_datasets.fake_1000

settings = SettingsCreator(
    link_type="dedupe_only",
    comparisons=[
        cl.JaroWinklerAtThresholds("first_name"),
        cl.JaccardAtThresholds("surname"),
    ],
    blocking_rules_to_generate_predictions=[
        block_on("first_name"),
    ],
    max_iterations=2,
)

linker = Linker(df, settings, db_api)
import logging

logging.basicConfig(format="%(message)s")
logging.getLogger("splink").setLevel(1)

pairwise_predictions = linker.inference.predict(threshold_match_weight=-10)

import sqlglot


sql = (
    settings.get_settings("duckdb")
    .comparisons[0]
    ._case_statement.replace("WHEN", "\nWHEN")
)

print(sql)
sqlglot.parse_one(sql)

and results in:

SQL pipeline was passed inputs [] and output dataset __splink__df_concat_with_tf
    Pipeline part 1: CTE reads tables [ __splink__input_table_0 ] and has output table name: __splink__df_concat
    Pipeline part 2: CTE reads tables [ __splink__df_concat ] and has output table name: __splink__df_concat_with_tf
Executing sql to create __splink__df_concat_with_tf as physical name __splink__df_concat_with_tf_86c7e29e5

------Start SQL---------
CREATE TABLE __splink__df_concat_with_tf_86c7e29e5 AS
WITH

__splink__df_concat as (
            select "unique_id", "first_name", "surname", "dob", "city", "email", "cluster"
            , random() as __splink_salt
            from __splink__input_table_0
            )
select * from __splink__df_concat
-------End SQL-----------

Setting cache for __splink__df_concat_with_tf_86c7e29e5 with physical name __splink__df_concat_with_tf_86c7e29e5
Setting cache for __splink__df_concat_with_tf with physical name __splink__df_concat_with_tf_86c7e29e5
SQL pipeline was passed inputs [__splink__df_concat_with_tf_86c7e29e5] and output dataset __splink__blocked_id_pairs
    Pipeline part 1: CTE reads tables [ __splink__df_concat_with_tf_86c7e29e5 ] and has output table name: __splink__df_concat_with_tf
    Pipeline part 2: CTE reads tables [ __splink__df_concat_with_tf AS r, __splink__df_concat_with_tf AS l ] and has output table name: __splink__blocked_id_pairs
Executing sql to create __splink__blocked_id_pairs as physical name __splink__blocked_id_pairs_4142a51b5

------Start SQL---------
CREATE TABLE __splink__blocked_id_pairs_4142a51b5 AS
WITH

__splink__df_concat_with_tf as (
select * from __splink__df_concat_with_tf_86c7e29e5)

            select
            '0' as match_key,
            l."unique_id" as join_key_l,
            r."unique_id" as join_key_r
            from __splink__df_concat_with_tf as l
            inner join __splink__df_concat_with_tf as r
            on
            (l."first_name" = r."first_name")
            where l."unique_id" < r."unique_id"


-------End SQL-----------

Setting cache for __splink__blocked_id_pairs_4142a51b5 with physical name __splink__blocked_id_pairs_4142a51b5
Blocking time: 0.00 seconds
SQL pipeline was passed inputs [__splink__blocked_id_pairs_4142a51b5, __splink__df_concat_with_tf_86c7e29e5] and output dataset __splink__df_predict
    Pipeline part 1: CTE reads tables [ __splink__blocked_id_pairs_4142a51b5 ] and has output table name: __splink__blocked_id_pairs
    Pipeline part 2: CTE reads tables [ __splink__df_concat_with_tf_86c7e29e5 ] and has output table name: __splink__df_concat_with_tf
    Pipeline part 3: CTE reads tables [ __splink__df_concat_with_tf AS r, __splink__blocked_id_pairs AS b, __splink__df_concat_with_tf AS l ] and has output table name: blocked_with_cols
    Pipeline part 4: CTE reads tables [ blocked_with_cols ] and has output table name: __splink__df_comparison_vectors
    Pipeline part 5: CTE reads tables [ __splink__df_comparison_vectors ] and has output table name: __splink__df_match_weight_parts
    Pipeline part 6: CTE reads tables [ __splink__df_match_weight_parts ] and has output table name: __splink__df_predict
Executing sql to create __splink__df_predict as physical name __splink__df_predict_09540489c

------Start SQL---------
CREATE TABLE __splink__df_predict_09540489c AS
WITH

__splink__blocked_id_pairs as (
select * from __splink__blocked_id_pairs_4142a51b5),

__splink__df_concat_with_tf as (
select * from __splink__df_concat_with_tf_86c7e29e5),

blocked_with_cols as (
    select "l"."unique_id" AS "unique_id_l",
"r"."unique_id" AS "unique_id_r",
"l"."first_name" AS "first_name_l",
"r"."first_name" AS "first_name_r",
"l"."surname" AS "surname_l",
"r"."surname" AS "surname_r", b.match_key
    from __splink__df_concat_with_tf as l
    inner join __splink__blocked_id_pairs as b
    on l."unique_id" = b.join_key_l
    inner join __splink__df_concat_with_tf as r
    on r."unique_id" = b.join_key_r
    ),

__splink__df_comparison_vectors as (
    select "unique_id_l",
"unique_id_r",
"first_name_l",
"first_name_r",
CASE WHEN "first_name_l" IS NULL OR "first_name_r" IS NULL THEN -1 WHEN "first_name_l" = "first_name_r" THEN 3 WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 2 WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.7 THEN 1 ELSE 0 END as gamma_first_name,
"surname_l",
"surname_r",
CASE WHEN "surname_l" IS NULL OR "surname_r" IS NULL THEN -1 WHEN "surname_l" = "surname_r" THEN 3 WHEN jaccard("surname_l", "surname_r") >= 0.9 THEN 2 WHEN jaccard("surname_l", "surname_r") >= 0.7 THEN 1 ELSE 0 END as gamma_surname
    from blocked_with_cols
    ),

__splink__df_match_weight_parts as (
    select "unique_id_l","unique_id_r","first_name_l","first_name_r",gamma_first_name,CASE
WHEN
gamma_first_name = -1
THEN cast(1.0 as float8)

WHEN
gamma_first_name = 3
THEN cast(1024.0 as float8)

WHEN
gamma_first_name = 2
THEN cast(8.0 as float8)

WHEN
gamma_first_name = 1
THEN cast(0.5 as float8)

WHEN
gamma_first_name = 0
THEN cast(0.03125 as float8)
 END as bf_first_name ,"surname_l","surname_r",gamma_surname,CASE
WHEN
gamma_surname = -1
THEN cast(1.0 as float8)

WHEN
gamma_surname = 3
THEN cast(1024.0 as float8)

WHEN
gamma_surname = 2
THEN cast(8.0 as float8)

WHEN
gamma_surname = 1
THEN cast(0.5 as float8)

WHEN
gamma_surname = 0
THEN cast(0.03125 as float8)
 END as bf_surname
    from __splink__df_comparison_vectors
    )

    select
    log2(cast(0.00010001000100010001 as float8) * bf_first_name * bf_surname) as match_weight,
    CASE WHEN bf_first_name = cast('infinity' as float8) OR bf_surname = cast('infinity' as float8) THEN 1.0 ELSE (cast(0.00010001000100010001 as float8) * bf_first_name * bf_surname)/(1+(cast(0.00010001000100010001 as float8) * bf_first_name * bf_surname)) END as match_probability,
    "unique_id_l","unique_id_r","first_name_l","first_name_r",gamma_first_name,"surname_l","surname_r",gamma_surname
    from __splink__df_match_weight_parts
     where log2(cast(0.00010001000100010001 as float8) * bf_first_name * bf_surname) >= -10

-------End SQL-----------

Setting cache for __splink__df_predict_09540489c with physical name __splink__df_predict_09540489c
Predict time: 0.03 seconds

 -- WARNING --
You have called predict(), but there are some parameter estimates which have neither been estimated or specified in your settings dictionary.  To produce predictions the following untrained trained parameters will use default values.
Comparison: 'first_name':
    m values not fully trained
Comparison: 'first_name':
    u values not fully trained
Comparison: 'surname':
    m values not fully trained
Comparison: 'surname':
    u values not fully trained
The 'probability_two_random_records_match' setting has been set to the default value (0.0001).
If this is not the desired behaviour, either:
 - assign a value for `probability_two_random_records_match` in your settings dictionary, or
 - estimate with the `linker.estimate_probability_two_random_records_match` function.
Dropping table with templated name __splink__blocked_id_pairs and physical name __splink__blocked_id_pairs_4142a51b5
CASE
WHEN "first_name_l" IS NULL OR "first_name_r" IS NULL THEN -1
WHEN "first_name_l" = "first_name_r" THEN 3
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 2
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.7 THEN 1 ELSE 0 END as gamma_first_name
Alias(
  this=Case(
    ifs=[
      If(
        this=Or(
          this=Is(
            this=Column(
              this=Identifier(this=first_name_l, quoted=True)),
            expression=Null()),
          expression=Is(
            this=Column(
              this=Identifier(this=first_name_r, quoted=True)),
            expression=Null())),
        true=Neg(
          this=Literal(this=1, is_string=False))),
      If(
        this=EQ(
          this=Column(
            this=Identifier(this=first_name_l, quoted=True)),
          expression=Column(
            this=Identifier(this=first_name_r, quoted=True))),
        true=Literal(this=3, is_string=False)),
      If(
        this=GTE(
          this=Anonymous(
            this=jaro_winkler_similarity,
            expressions=[
              Column(
                this=Identifier(this=first_name_l, quoted=True)),
              Column(
                this=Identifier(this=first_name_r, quoted=True))]),
          expression=Literal(this=0.9, is_string=False)),
        true=Literal(this=2, is_string=False)),
      If(
        this=GTE(
          this=Anonymous(
            this=jaro_winkler_similarity,
            expressions=[
              Column(
                this=Identifier(this=first_name_l, quoted=True)),
              Column(
                this=Identifier(this=first_name_r, quoted=True))]),
          expression=Literal(this=0.7, is_string=False)),
        true=Literal(this=1, is_string=False))],
    default=Literal(this=0, is_string=False)),
  alias=Identifier(this=gamma_first_name, quoted=False))

Let's make this an argument on predict() for now, i.e. predict(experimental_optimisation=True)




RobinL commented Jul 4, 2025

I got Gemini to analyse the whole sqlglot codebase, and it thinks these are the relevant files for the purpose of solving this task:

mkdir -p sqlglot_refactoring_task && rsync -aR \
  ./posts/ast_primer.md \
  ./posts/onboarding.md \
  ./README.md \
  ./sqlglot/__init__.py \
  ./sqlglot/expressions.py \
  ./sqlglot/parser.py \
  ./sqlglot/generator.py \
  ./sqlglot/tokens.py \
  ./sqlglot/helper.py \
  ./sqlglot/dialects/dialect.py \
  ./sqlglot/optimizer/optimizer.py \
  ./sqlglot/optimizer/eliminate_subqueries.py \
  ./sqlglot/optimizer/scope.py \
  ./sqlglot/optimizer/qualify.py \
  sqlglot_refactoring_task/

Test script:

Details

from splink.internals.reusable_function_detection import (
    _find_repeated_functions,
    _build_reusable_functions_sql,
)

sql = """CASE
WHEN list_reduce(list_prepend(1.0, list_transform(list_intersect(name_tokens_with_freq_l, name_tokens_with_freq_r), x -> CAST(x.rel_freq AS FLOAT))), (p, q) -> p * q) < 1e-12 then 2
WHEN list_reduce(list_prepend(1.0, list_transform(list_intersect(name_tokens_with_freq_l, name_tokens_with_freq_r), x -> CAST(x.rel_freq AS FLOAT))), (p, q) -> p * q) < 1e-12 then 1
ELSE 0
END
"""


columns = [sql]

repeated_functions, modified_columns = _find_repeated_functions(columns, "duckdb")

for func in repeated_functions:
    print(func["function_sql"])
    print(func["alias"])

print("--")
for m in modified_columns:
    print(m.replace("WHEN", "\nWHEN"))

RobinL commented Jul 4, 2025

Is this a better implementation?

Details

from __future__ import annotations

from collections import Counter
from typing import Dict, List, Set, Tuple

from sqlglot import exp, parse_one


# --------------------------------------------------------------------------- #
# 1. Helpers
# --------------------------------------------------------------------------- #


def _func_key(node: exp.Func) -> str:
    """
    A hashable structural key for a function expression.

    Here we simply use the generated SQL string, which is stable across
    formatting and whitespace differences.
    """
    return node.sql()


def _gather_counts(asts: List[exp.Expression]) -> Counter:
    """Count every Func node across all supplied ASTs."""
    counts: Counter = Counter()
    for ast in asts:
        for func in ast.find_all(exp.Func):
            counts[_func_key(func)] += 1
    return counts


def _suppress_nested(duplicated: Set[str], asts: List[exp.Expression]) -> Set[str]:
    """
    Remove keys that are *strictly* nested inside another duplicated function key.
    """
    keep: Set[str] = set(duplicated)
    for ast in asts:
        for func in ast.find_all(exp.Func):
            k = _func_key(func)
            if k not in duplicated:
                continue
            parent = func.parent
            while parent:
                if isinstance(parent, exp.Func) and _func_key(parent) in duplicated:
                    keep.discard(k)  # child is nested => don’t treat as reusable
                    break
                parent = parent.parent
    return keep


# --------------------------------------------------------------------------- #
# 2. Public API (replacements for your originals)
# --------------------------------------------------------------------------- #


def _find_repeated_functions(
    columns_to_select_for_comparison_vector_values: List[str],
    sqlglot_dialect: str,
) -> Tuple[List[Dict[str, str]], List[str]]:
    """
    Detect repeated function calls in CASE expressions, replace them with
    column references, and return (metadata, modified_columns).

    The return contract matches your original helper exactly.
    """
    # ── 1. parse CASE expressions ────────────────────────────────────────────
    case_sqls = [
        col
        for col in columns_to_select_for_comparison_vector_values
        if col.lstrip().upper().startswith("CASE")
    ]
    asts = [parse_one(sql, read=sqlglot_dialect) for sql in case_sqls]

    # ── 2. count & filter duplicates ─────────────────────────────────────────
    counts = _gather_counts(asts)
    duplicated = {k for k, c in counts.items() if c > 1}
    duplicated = _suppress_nested(duplicated, asts)

    if not duplicated:
        # nothing to do – early-out to avoid a second walk
        return [], columns_to_select_for_comparison_vector_values[:]

    # ── 3. allocate stable aliases & build lookup tables ─────────────────────
    var_mapping: Dict[str, str] = {}
    repeated_functions: List[Dict[str, str]] = []

    # pick a deterministic order so test runs are reproducible
    # (sort by the SQL string itself; hash() is not stable across processes)
    for idx, key in enumerate(sorted(duplicated), 1):
        # locate *one* concrete Func node to turn back into SQL for the CTE
        expr = next(
            func
            for ast in asts
            for func in ast.find_all(exp.Func)
            if _func_key(func) == key
        )
        alias = f"rf_{idx}"
        var_mapping[key] = alias
        repeated_functions.append(
            {"function_sql": expr.sql(dialect=sqlglot_dialect), "alias": alias}
        )

    # ── 4. transform original SQL list ───────────────────────────────────────
    def _replace(node: exp.Expression):
        if isinstance(node, exp.Func) and _func_key(node) in var_mapping:
            # but leave it untouched if its *ancestor* will become a column
            parent = node.parent
            while parent:
                if isinstance(parent, exp.Func) and _func_key(parent) in var_mapping:
                    return node  # nested → do not replace
                parent = parent.parent
            return exp.to_identifier(var_mapping[_func_key(node)])
        return node

    modified_columns: List[str] = []
    for original in columns_to_select_for_comparison_vector_values:
        if original.lstrip().upper().startswith("CASE"):
            mod_ast = parse_one(original, read=sqlglot_dialect).transform(_replace)
            modified_columns.append(mod_ast.sql(dialect=sqlglot_dialect))
        else:
            modified_columns.append(original)

    return repeated_functions, modified_columns


def _build_reusable_functions_sql(repeated_functions: List[Dict[str, str]]) -> str:
    """Generate the SELECT …, <func> AS <alias> layer."""
    if not repeated_functions:
        return "SELECT * FROM blocked_with_cols"

    computed = ",\n           ".join(
        f"{f['function_sql']} AS {f['alias']}" for f in repeated_functions
    )
    return f"""
    SELECT *,
           {computed}
    FROM blocked_with_cols
    """.strip()

Oh wait, if `__hash__` is usable then `__eq__` is too, meaning we can use a different approach

Details

# reusable_function_detection.py
from __future__ import annotations

from collections import Counter
from typing import Dict, List, Tuple

from sqlglot import exp, parse_one


def _find_repeated_functions(
    columns_to_select_for_comparison_vector_values: List[str],
    sqlglot_dialect: str,
) -> Tuple[List[Dict[str, str]], List[str]]:
    """
    Detect function sub-expressions that are used more than once inside CASE
    expressions and rewrite the columns so they refer to computed aliases.

    Returns
    -------
    repeated_functions : list of {"function_sql": str, "alias": str}
        One entry per *root* duplicate function, in deterministic order.
    modified_columns   : list[str]
        The input columns, with duplicates replaced by their alias.
    """
    # --- 1. Parse only the CASE expressions ---------------------------------
    case_asts = [
        parse_one(col, read=sqlglot_dialect)
        for col in columns_to_select_for_comparison_vector_values
        if col.lstrip().upper().startswith("CASE")
    ]

    # --- 2. Count every function node across all CASEs ----------------------
    func_counts: Counter[exp.Expression] = Counter()
    for ast in case_asts:
        func_counts.update(fn for fn in ast.find_all(exp.Func))

    repeated: set[exp.Expression] = {fn for fn, c in func_counts.items() if c > 1}

    # --- 3. Keep only the *root* duplicates (not nested inside another dup) --
    def is_nested_in_repeated(fn: exp.Func) -> bool:
        parent = fn.parent
        while parent:
            if isinstance(parent, exp.Func) and parent in repeated:
                return True
            parent = parent.parent
        return False

    roots_in_order: list[exp.Func] = []
    seen: set[exp.Expression] = set()  # protect against re-adding same struct
    for ast in case_asts:
        for fn in ast.find_all(exp.Func):
            if fn in repeated and not is_nested_in_repeated(fn) and fn not in seen:
                roots_in_order.append(fn)
                seen.add(fn)

    # --- 4. Give each root an alias -----------------------------------------
    var_mapping: Dict[exp.Expression, str] = {}
    repeated_functions: list[dict[str, str]] = []

    for idx, fn in enumerate(roots_in_order, start=1):
        alias = f"rf_{idx}"
        var_mapping[fn] = alias
        repeated_functions.append(
            {
                "function_sql": fn.sql(dialect=sqlglot_dialect),
                "alias": alias,
            }
        )

    # --- 5. Replace in every CASE AST ---------------------------------------
    def _replace(node: exp.Expression) -> exp.Expression:
        # Only root duplicates are in var_mapping
        if isinstance(node, exp.Func) and node in var_mapping:
            return exp.to_identifier(var_mapping[node])
        return node

    modified_columns: list[str] = []
    for col in columns_to_select_for_comparison_vector_values:
        if col.lstrip().upper().startswith("CASE"):
            ast = parse_one(col, read=sqlglot_dialect)
            modified_columns.append(
                ast.transform(_replace).sql(dialect=sqlglot_dialect)
            )
        else:
            modified_columns.append(col)

    return repeated_functions, modified_columns


def _build_reusable_functions_sql(repeated_functions: List[Dict[str, str]]) -> str:
    """
    Build a CTE that adds all reusable columns to `blocked_with_cols`.
    """
    if not repeated_functions:
        return "SELECT * FROM blocked_with_cols"

    computed_cols = ",\n           ".join(
        f"{f['function_sql']} AS {f['alias']}" for f in repeated_functions
    )
    return f"""SELECT *,
           {computed_cols}
    FROM blocked_with_cols"""
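
A quick check of the second version against one of the CASE statements from this thread, calling the two helpers defined above; the outputs in the comments are roughly what I'd expect:

```python
case_sql = """CASE
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.9 THEN 2
WHEN jaro_winkler_similarity("first_name_l", "first_name_r") >= 0.7 THEN 1
ELSE 0 END as gamma_first_name"""

repeated, modified = _find_repeated_functions([case_sql], "duckdb")

for f in repeated:
    print(f["alias"], "=", f["function_sql"])
# rf_1 = jaro_winkler_similarity("first_name_l", "first_name_r")

print(modified[0])
# CASE WHEN rf_1 >= 0.9 THEN 2 WHEN rf_1 >= 0.7 THEN 1 ELSE 0 END AS gamma_first_name

print(_build_reusable_functions_sql(repeated))
# SELECT *,
#        jaro_winkler_similarity("first_name_l", "first_name_r") AS rf_1
#     FROM blocked_with_cols
```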

RobinL commented Jul 27, 2025

see #2738

@RobinL RobinL closed this Jul 27, 2025

Successfully merging this pull request may close these issues.

Improve runtimes by 'pushing up' common Case Statements into precomputed values