Skip to content

Commit fc2efa4

Browse files
committed
Update detection
1 parent 7c57499 commit fc2efa4

File tree

2 files changed

+121
-9
lines changed

2 files changed

+121
-9
lines changed

sdv/metadata/multi_table.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -543,25 +543,36 @@ def _detect_foreign_keys_by_column_name(self, data):
543543
"""
544544
for parent_candidate in self.tables.keys():
545545
primary_key = self.tables[parent_candidate].primary_key
546+
if primary_key is None:
547+
continue
548+
549+
pk_sdtype = self.tables[parent_candidate].columns[primary_key]['sdtype']
546550
for child_candidate in self.tables.keys() - {parent_candidate}:
547551
child_meta = self.tables[child_candidate]
548552
if primary_key in child_meta.columns.keys():
553+
original_fk_meta = deepcopy(child_meta.columns[primary_key])
554+
original_fk_sdtype = original_fk_meta['sdtype']
555+
if pk_sdtype != 'id' and original_fk_sdtype != pk_sdtype:
556+
continue
557+
549558
try:
550-
original_foreign_key_sdtype = child_meta.columns[primary_key]['sdtype']
551-
if original_foreign_key_sdtype != 'id':
559+
if pk_sdtype == 'id' and original_fk_sdtype != 'id':
552560
self.update_column(
553-
table_name=child_candidate, column_name=primary_key, sdtype='id'
561+
table_name=child_candidate,
562+
column_name=primary_key,
563+
sdtype='id',
554564
)
555-
556565
self.add_relationship(
557566
parent_candidate, child_candidate, primary_key, primary_key
558567
)
568+
559569
except InvalidMetadataError:
560-
self.update_column(
561-
table_name=child_candidate,
562-
column_name=primary_key,
563-
sdtype=original_foreign_key_sdtype,
564-
)
570+
if pk_sdtype == 'id' and original_fk_sdtype != 'id':
571+
self.update_column(
572+
table_name=child_candidate,
573+
column_name=primary_key,
574+
**original_fk_meta,
575+
)
565576
continue
566577

567578
def _detect_relationships(self, data=None, foreign_key_inference_algorithm='column_name_match'):

tests/unit/metadata/test_multi_table.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from sdv.errors import InvalidDataError
1414
from sdv.metadata.errors import InvalidMetadataError
15+
from sdv.metadata.metadata import Metadata
1516
from sdv.metadata.multi_table import MultiTableMetadata, SingleTableMetadata
1617
from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata
1718

@@ -2624,6 +2625,106 @@ def test__detect_relationships(self):
26242625
assert instance.relationships == expected_relationships
26252626
assert instance.tables['sessions'].columns['user_id']['sdtype'] == 'id'
26262627

2628+
def test__detect_relationships_semantic_foreign_key(self):
    """Test a same-named, same-sdtype semantic key is linked without being rewritten.

    The parent's primary key ``email`` has sdtype ``email`` and the child holds a
    column with the same name and sdtype. Detection is expected to add the
    relationship and leave the child's column metadata (``pii`` included) as is.
    """
    # Setup
    parent_table = {
        'primary_key': 'email',
        'columns': {
            'email': {'sdtype': 'email'},
            'user_name': {'sdtype': 'categorical'},
        },
    }
    child_table = {
        'primary_key': 'child_id',
        'columns': {
            'child_id': {'sdtype': 'id'},
            'email': {'sdtype': 'email', 'pii': True},
        },
    }
    metadata = Metadata.load_from_dict({
        'tables': {'parent': parent_table, 'child': child_table},
        'relationships': [],
    })

    # Run
    metadata._detect_relationships()

    # Assert
    detected_relationships = metadata.relationships
    assert detected_relationships == [
        {
            'parent_table_name': 'parent',
            'child_table_name': 'child',
            'parent_primary_key': 'email',
            'child_foreign_key': 'email',
        }
    ]
    child_email_column = metadata.tables['child'].columns['email']
    assert child_email_column == {'sdtype': 'email', 'pii': True}
2665+
2666+
def test__detect_relationships_semantic_foreign_key_does_not_overwrite_mismatch(self):
    """Test a child column whose sdtype differs from a semantic primary key is left alone.

    The parent's primary key ``email`` has sdtype ``email`` while the child's
    ``email`` column is ``categorical``. No relationship is expected and the
    child's sdtype must stay ``categorical``.
    """
    # Setup
    parent_table = {
        'primary_key': 'email',
        'columns': {
            'email': {'sdtype': 'email'},
            'user_name': {'sdtype': 'categorical'},
        },
    }
    child_table = {
        'primary_key': 'child_id',
        'columns': {
            'child_id': {'sdtype': 'id'},
            'email': {'sdtype': 'categorical'},
        },
    }
    metadata = Metadata.load_from_dict({
        'tables': {'parent': parent_table, 'child': child_table},
        'relationships': [],
    })

    # Run
    metadata._detect_relationships()

    # Assert
    assert metadata.relationships == []
    child_email_column = metadata.tables['child'].columns['email']
    assert child_email_column['sdtype'] == 'categorical'
2695+
2696+
def test__detect_relationships_restores_foreign_key_metadata_after_failure(self):
    """Test the child column metadata is fully restored when adding the relationship fails.

    ``add_relationship`` is replaced with a mock that raises
    ``InvalidMetadataError``; after detection the ``user_id`` column of
    ``sessions`` must equal its original metadata, extra keys included.
    """
    # Setup
    fk_metadata_before = {'sdtype': 'email', 'pii': True}
    users_table = {
        'primary_key': 'user_id',
        'columns': {
            'user_id': {'sdtype': 'id'},
            'user_name': {'sdtype': 'categorical'},
        },
    }
    sessions_table = {
        'primary_key': 'session_id',
        'columns': {
            # Hand over a copy so the expected value cannot be mutated through the metadata.
            'user_id': dict(fk_metadata_before),
            'session_id': {'sdtype': 'id'},
        },
    }
    metadata = Metadata.load_from_dict({
        'tables': {'users': users_table, 'sessions': sessions_table},
        'relationships': [],
    })
    failing_add_relationship = Mock(side_effect=InvalidMetadataError('bad relationship'))
    metadata.add_relationship = failing_add_relationship

    # Run
    metadata._detect_relationships()

    # Assert
    failing_add_relationship.assert_called_once_with(
        'users', 'sessions', 'user_id', 'user_id'
    )
    restored_column = metadata.tables['sessions'].columns['user_id']
    assert restored_column == fk_metadata_before
2727+
26272728
def test__detect_relationships_circular(self):
26282729
"""Test that relationships that invalidate the metadata are not added."""
26292730
# Setup

0 commit comments

Comments
 (0)