Skip to content

Commit 085beaa

Browse files
committed
Fix remove dataset
1 parent 9d61cb6 commit 085beaa

File tree

2 files changed

+112
-3
lines changed

2 files changed

+112
-3
lines changed

sdv/metadata/metadata.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,22 @@ def _remove_matching_relationships(self, element, keys):
335335

336336
self.relationships = updated_relationships
337337

338+
def _remove_column_relationships(self, table_name, column_name):
    """Remove the relationships in which the column is a key of the given table.

    A relationship is dropped when ``column_name`` is its ``child_foreign_key``
    and ``table_name`` is its ``child_table_name``, or when ``column_name`` is
    its ``parent_primary_key`` and ``table_name`` is its ``parent_table_name``.
    Relationships that use a column with the same name in a different table
    are kept. ``self.relationships`` is rebound to a new list; the original
    list object is not modified in place.

    Args:
        table_name (str):
            Name of the table that owns the column.
        column_name (str):
            Name of the column whose relationships should be removed.
    """
    # Matching on the column name alone is not enough: several tables may share
    # a key name, so the owning table must match on the same side as well.
    self.relationships = [
        relationship
        for relationship in self.relationships
        if not (
            (
                relationship['child_foreign_key'] == column_name
                and relationship['child_table_name'] == table_name
            )
            or (
                relationship['parent_primary_key'] == column_name
                and relationship['parent_table_name'] == table_name
            )
        )
    ]
353+
338354
def remove_table(self, table_name):
339355
"""Remove a table from the metadata.
340356
@@ -380,9 +396,7 @@ def remove_column(self, column_name, table_name=None):
380396
table_metadata._validate_column_exists(column_name)
381397

382398
# Remove relationships
383-
self._remove_matching_relationships(
384-
column_name, ['parent_primary_key', 'child_foreign_key']
385-
)
399+
self._remove_column_relationships(table_name, column_name)
386400
updated_column_relationships = []
387401
for column_relationship in table_metadata.column_relationships:
388402
if column_name not in column_relationship.get('column_names', []):

tests/unit/metadata/test_metadata.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,101 @@ def test_remove_column_removes_relationships(self):
10451045
assert list(manufacturer_mock.columns.keys()) == ['country', 'id']
10461046
assert metadata._multi_table_updated
10471047

1048+
def test__remove_column_relationships_only_removes_matching_table(self):
    """Test that a same-named foreign key in another child table keeps its relationship."""
    # Setup
    metadata = Metadata()
    # Both child tables point at the same parent through a column named 'fk'.
    metadata.relationships = [
        {
            'parent_table_name': 'parent',
            'parent_primary_key': 'id',
            'child_table_name': child_table,
            'child_foreign_key': 'fk',
        }
        for child_table in ('child_a', 'child_b')
    ]
    expected_relationships = [
        {
            'parent_table_name': 'parent',
            'parent_primary_key': 'id',
            'child_table_name': 'child_b',
            'child_foreign_key': 'fk',
        },
    ]

    # Run
    metadata._remove_column_relationships('child_a', 'fk')

    # Assert
    assert metadata.relationships == expected_relationships
1079+
1080+
def test_remove_column_only_removes_relationship_for_that_table(self):
    """Test that removing a foreign key leaves same-named keys of other tables untouched."""
    # Setup
    child_tables = ('table2', 'table3')
    tables = {
        'table1': {
            'primary_key': 'id',
            'columns': {
                'id': {'sdtype': 'id'},
                'A': {'sdtype': 'numerical'},
                'B': {'sdtype': 'categorical'},
            },
        },
    }
    # Each child table gets its own freshly built definition so that removing
    # the column from one table cannot affect the other through shared dicts.
    for child_table in child_tables:
        tables[child_table] = {
            'primary_key': 'id',
            'columns': {
                'id': {'sdtype': 'id'},
                'fk_1': {'sdtype': 'id'},
                'A': {'sdtype': 'numerical'},
                'B': {'sdtype': 'categorical'},
            },
        }

    relationships = [
        {
            'parent_table_name': 'table1',
            'parent_primary_key': 'id',
            'child_table_name': child_table,
            'child_foreign_key': 'fk_1',
        }
        for child_table in child_tables
    ]
    metadata = Metadata.load_from_dict({
        'tables': tables,
        'relationships': relationships,
    })
    expected_relationships = [
        {
            'parent_table_name': 'table1',
            'parent_primary_key': 'id',
            'child_table_name': 'table3',
            'child_foreign_key': 'fk_1',
        },
    ]

    # Run
    metadata.remove_column('fk_1', 'table2')

    # Assert
    assert metadata.relationships == expected_relationships
    assert 'fk_1' not in metadata.tables['table2'].columns
    assert 'fk_1' in metadata.tables['table3'].columns
1142+
10481143
def test_remove_column_sequence_key(self):
10491144
"""Test the method also remove the sequence key if the column is one."""
10501145
# Setup

0 commit comments

Comments
 (0)