-
Notifications
You must be signed in to change notification settings - Fork 205
Introduce Symbolic Constraint Solver for SQL-Driven Data Generation #564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Half way through the README.md. Will continue reading and then proceed to the code.
How does the system in general handle expressions where the values depend on each other.
Eg.
SELECT * FROM test.suitcase WHERE width + height + length < 25
Does this need a new domain type?
| /** | ||
| * Copyright 2025 LinkedIn Corporation. All rights reserved. | ||
| * Licensed under the BSD-2 Clause license. | ||
| * See LICENSE in the project root for license information. | ||
| */ | ||
| package com.linkedin.coral.datagen.domain; | ||
|
|
||
| import java.util.Arrays; | ||
| import java.util.List; | ||
|
|
||
| import org.testng.annotations.Test; | ||
|
|
||
|
|
||
| /** | ||
| * Tests for IntegerDomain class. | ||
| */ | ||
| public class IntegerDomainTest { | ||
|
|
||
| @Test | ||
| public void testSingleValue() { | ||
| System.out.println("\n=== Single Value Test ==="); | ||
| IntegerDomain domain = IntegerDomain.of(42); | ||
| System.out.println("Domain: " + domain); | ||
| System.out.println("Is empty: " + domain.isEmpty()); | ||
| System.out.println("Contains 42: " + domain.contains(42)); | ||
| System.out.println("Contains 43: " + domain.contains(43)); | ||
| System.out.println("Samples: " + domain.sampleValues(5)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testSingleInterval() { | ||
| System.out.println("\n=== Single Interval Test ==="); | ||
| IntegerDomain domain = IntegerDomain.of(10, 20); | ||
| System.out.println("Domain: " + domain); | ||
| System.out.println("Contains 10: " + domain.contains(10)); | ||
| System.out.println("Contains 15: " + domain.contains(15)); | ||
| System.out.println("Contains 20: " + domain.contains(20)); | ||
| System.out.println("Contains 21: " + domain.contains(21)); | ||
| System.out.println("Samples: " + domain.sampleValues(5)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testMultipleIntervals() { | ||
| System.out.println("\n=== Multiple Intervals Test ==="); | ||
| List<IntegerDomain.Interval> intervals = Arrays.asList(new IntegerDomain.Interval(1, 5), | ||
| new IntegerDomain.Interval(10, 15), new IntegerDomain.Interval(20, 30)); | ||
| IntegerDomain domain = IntegerDomain.of(intervals); | ||
| System.out.println("Domain: " + domain); | ||
| System.out.println("Contains 3: " + domain.contains(3)); | ||
| System.out.println("Contains 7: " + domain.contains(7)); | ||
| System.out.println("Contains 12: " + domain.contains(12)); | ||
| System.out.println("Contains 25: " + domain.contains(25)); | ||
| System.out.println("Samples: " + domain.sampleValues(10)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testIntersection() { | ||
| System.out.println("\n=== Intersection Test ==="); | ||
| IntegerDomain domain1 = IntegerDomain.of(1, 20); | ||
| IntegerDomain domain2 = IntegerDomain.of(10, 30); | ||
| IntegerDomain intersection = domain1.intersect(domain2); | ||
| System.out.println("Domain 1: " + domain1); | ||
| System.out.println("Domain 2: " + domain2); | ||
| System.out.println("Intersection: " + intersection); | ||
| System.out.println("Samples: " + intersection.sampleValues(5)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testUnion() { | ||
| System.out.println("\n=== Union Test ==="); | ||
| IntegerDomain domain1 = IntegerDomain.of(1, 10); | ||
| IntegerDomain domain2 = IntegerDomain.of(20, 30); | ||
| IntegerDomain union = domain1.union(domain2); | ||
| System.out.println("Domain 1: " + domain1); | ||
| System.out.println("Domain 2: " + domain2); | ||
| System.out.println("Union: " + union); | ||
| System.out.println("Samples: " + union.sampleValues(10)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testAddConstant() { | ||
| System.out.println("\n=== Add Constant Test ==="); | ||
| IntegerDomain domain = IntegerDomain.of(10, 20); | ||
| IntegerDomain shifted = domain.add(5); | ||
| System.out.println("Original domain: " + domain); | ||
| System.out.println("After adding 5: " + shifted); | ||
| System.out.println("Samples: " + shifted.sampleValues(5)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testMultiplyConstant() { | ||
| System.out.println("\n=== Multiply Constant Test ==="); | ||
| IntegerDomain domain = IntegerDomain.of(10, 20); | ||
| IntegerDomain scaled = domain.multiply(2); | ||
| System.out.println("Original domain: " + domain); | ||
| System.out.println("After multiplying by 2: " + scaled); | ||
| System.out.println("Samples: " + scaled.sampleValues(5)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testNegativeMultiply() { | ||
| System.out.println("\n=== Negative Multiply Test ==="); | ||
| IntegerDomain domain = IntegerDomain.of(10, 20); | ||
| IntegerDomain scaled = domain.multiply(-1); | ||
| System.out.println("Original domain: " + domain); | ||
| System.out.println("After multiplying by -1: " + scaled); | ||
| System.out.println("Samples: " + scaled.sampleValues(5)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testOverlappingIntervalsMerge() { | ||
| System.out.println("\n=== Overlapping Intervals Merge Test ==="); | ||
| List<IntegerDomain.Interval> intervals = Arrays.asList(new IntegerDomain.Interval(1, 10), | ||
| new IntegerDomain.Interval(5, 15), new IntegerDomain.Interval(20, 30)); | ||
| IntegerDomain domain = IntegerDomain.of(intervals); | ||
| System.out.println("Input intervals: [1, 10], [5, 15], [20, 30]"); | ||
| System.out.println("Merged domain: " + domain); | ||
| System.out.println("Samples: " + domain.sampleValues(10)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testAdjacentIntervalsMerge() { | ||
| System.out.println("\n=== Adjacent Intervals Merge Test ==="); | ||
| List<IntegerDomain.Interval> intervals = Arrays.asList(new IntegerDomain.Interval(1, 10), | ||
| new IntegerDomain.Interval(11, 20), new IntegerDomain.Interval(30, 40)); | ||
| IntegerDomain domain = IntegerDomain.of(intervals); | ||
| System.out.println("Input intervals: [1, 10], [11, 20], [30, 40]"); | ||
| System.out.println("Merged domain: " + domain); | ||
| System.out.println("Samples: " + domain.sampleValues(10)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testEmptyDomain() { | ||
| System.out.println("\n=== Empty Domain Test ==="); | ||
| IntegerDomain empty = IntegerDomain.empty(); | ||
| System.out.println("Empty domain: " + empty); | ||
| System.out.println("Is empty: " + empty.isEmpty()); | ||
| System.out.println("Samples: " + empty.sampleValues(5)); | ||
| } | ||
|
|
||
| @Test | ||
| public void testIntersectionEmpty() { | ||
| System.out.println("\n=== Intersection Empty Test ==="); | ||
| IntegerDomain domain1 = IntegerDomain.of(1, 10); | ||
| IntegerDomain domain2 = IntegerDomain.of(20, 30); | ||
| IntegerDomain intersection = domain1.intersect(domain2); | ||
| System.out.println("Domain 1: " + domain1); | ||
| System.out.println("Domain 2: " + domain2); | ||
| System.out.println("Intersection: " + intersection); | ||
| System.out.println("Is empty: " + intersection.isEmpty()); | ||
| } | ||
|
|
||
| @Test | ||
| public void testComplexArithmetic() { | ||
| System.out.println("\n=== Complex Arithmetic Test ==="); | ||
| // Solve: 2*x + 5 = 25, where x in [0, 100] | ||
| // => 2*x = 20 | ||
| // => x = 10 | ||
| IntegerDomain output = IntegerDomain.of(25); | ||
| IntegerDomain afterSubtract = output.add(-5); // x = 20 | ||
| IntegerDomain solution = afterSubtract.multiply(1).intersect(IntegerDomain.of(0, 100)); | ||
|
|
||
| System.out.println("Equation: 2*x + 5 = 25"); | ||
| System.out.println("Output domain: " + output); | ||
| System.out.println("After subtracting 5: " + afterSubtract); | ||
| System.out.println("Solution (x must be in [0, 100]): " + solution); | ||
|
|
||
| // Verify | ||
| if (!solution.isEmpty()) { | ||
| long x = solution.sampleValues(1).get(0); | ||
| System.out.println("Sample x: " + x); | ||
| System.out.println("Verification: 2*" + x + " + 5 = " + (2 * x + 5)); | ||
| } | ||
| } | ||
|
|
||
| @Test | ||
| public void testMultiIntervalIntersection() { | ||
| System.out.println("\n=== Multi-Interval Intersection Test ==="); | ||
| List<IntegerDomain.Interval> intervals1 = | ||
| Arrays.asList(new IntegerDomain.Interval(1, 20), new IntegerDomain.Interval(30, 50)); | ||
| List<IntegerDomain.Interval> intervals2 = | ||
| Arrays.asList(new IntegerDomain.Interval(10, 35), new IntegerDomain.Interval(45, 60)); | ||
|
|
||
| IntegerDomain domain1 = IntegerDomain.of(intervals1); | ||
| IntegerDomain domain2 = IntegerDomain.of(intervals2); | ||
| IntegerDomain intersection = domain1.intersect(domain2); | ||
|
|
||
| System.out.println("Domain 1: " + domain1); | ||
| System.out.println("Domain 2: " + domain2); | ||
| System.out.println("Intersection: " + intersection); | ||
| System.out.println("Expected: [10, 20] ∪ [30, 35] ∪ [45, 50]"); | ||
| System.out.println("Samples: " + intersection.sampleValues(15)); | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These tests don't have assertions. Some other files have tests like these too.
| @Test | ||
| public void testArithmeticExpression() { | ||
| testDomainInference("Arithmetic Expression Test", "SELECT * FROM test.T WHERE age * 2 + 5 = 25", inputDomain -> { | ||
| assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When there is an error, a new test I'm adding still passes
@Test
public void testMultiVariateArithmeticExpression() {
testDomainInference("Arithmetic Expression Test", "SELECT * FROM test.suitcase WHERE width + height + length < 25", inputDomain -> {
assertTrue(inputDomain instanceof IntegerDomain, "Should be IntegerDomain");
IntegerDomain intDomain = (IntegerDomain) inputDomain;
System.out.println(intDomain);
assertTrue(intDomain.contains(10), "Should contain 10 (since 10 * 2 + 5 = 25)");
assertTrue(intDomain.contains(10), "Should contain 10 (since 10 * 2 + 5 = 25)");
assertTrue(intDomain.isSingleton(), "Should be singleton");
});
}
Introduce Symbolic Constraint Solver for SQL-Driven Data Generation
Overview
This PR introduces coral-data-generation, a symbolic constraint solver that inverts SQL expressions to derive input domain constraints. Instead of forward evaluation (generate → test → reject), it solves backward from predicates to derive what inputs must satisfy, enabling efficient test data generation with guaranteed constraint satisfaction.
Motivation
Problem: Traditional test data generation uses rejection sampling—generate random values, evaluate SQL predicates, discard mismatches. This is inefficient for complex nested expressions and cannot detect unsatisfiable queries.
Solution: Symbolic inversion treats SQL expressions as mathematical transformations with inverse functions. Starting from output constraints (e.g.,
= '50'), the system walks expression trees inward, applying inverse operations to derive input domains.Examples
1. Nested String Operations
2. Cross-Domain Arithmetic
3. Date Extraction with Type Casting
4. Complex Nested Substring
5. Contradiction Detection
6. Date String Pattern Matching
Key Components
1. Domain System
2. Transformer Architecture
Pluggable symbolic inversion functions implementing DomainTransformer:
SUBSTRING(x, start, len)with positional constraintsLOWER(x)via case-insensitive regex generationx + c = value→x = value - cx * c = value→x = value / c3. Relational Preprocessing
Normalizes Calcite RelNode trees for symbolic analysis:
4. Solver
DomainInferenceProgram: Top-down expression tree traversal with domain refinement at each step, detecting contradictions via empty domain intersection.
Technical Approach
Symbolic Inversion: For nested expression
f(g(h(x))) = constant:f⁻¹→ intermediate domaing⁻¹→ refined domainh⁻¹→ input constraint onxContradiction Detection: Multiple predicates on same variable → domain intersection. Empty result = unsatisfiable query.
Extensibility: Architecture supports multi-table inference (join propagation), fixed-point iteration (recursive constraints), and arbitrary domain types (date, decimal, enum).
Testing
Integration Tests (RegexDomainInferenceProgramTest): 14+ test scenarios covering simple/nested transformations, cross-domain CAST operations, arithmetic inversion, and contradiction detection. All tests validate generated samples satisfy original SQL predicates.
Documentation
This module comes with aomprehensive README with conceptual model, examples, and API reference.
Future Extensibility
The architecture naturally extends to additional domains (DecimalDomain, DateDomain), more transformers (CONCAT, REGEXP_EXTRACT), multi-table inference (join constraint propagation), and aggregate support (cardinality constraints).