|
12 | 12 |
|
13 | 13 | from sdv.errors import InvalidDataError |
14 | 14 | from sdv.metadata.errors import InvalidMetadataError |
| 15 | +from sdv.metadata.metadata import Metadata |
15 | 16 | from sdv.metadata.multi_table import MultiTableMetadata, SingleTableMetadata |
16 | 17 | from tests.utils import catch_sdv_logs, get_multi_table_data, get_multi_table_metadata |
17 | 18 |
|
@@ -2624,6 +2625,106 @@ def test__detect_relationships(self): |
2624 | 2625 | assert instance.relationships == expected_relationships |
2625 | 2626 | assert instance.tables['sessions'].columns['user_id']['sdtype'] == 'id' |
2626 | 2627 |
|
def test__detect_relationships_semantic_foreign_key(self):
    """Test a semantic primary key is linked to the same-named child column.

    The parent table is keyed on an ``email`` column and the child table carries an
    ``email`` column with the same sdtype. The test checks that the relationship is
    listed after detection and that the child column metadata equals what was loaded.
    """
    # Setup
    parent_table = {
        'primary_key': 'email',
        'columns': {
            'email': {'sdtype': 'email'},
            'user_name': {'sdtype': 'categorical'},
        },
    }
    child_table = {
        'primary_key': 'child_id',
        'columns': {
            'child_id': {'sdtype': 'id'},
            'email': {'sdtype': 'email', 'pii': True},
        },
    }
    metadata = Metadata.load_from_dict({
        'tables': {'parent': parent_table, 'child': child_table},
        'relationships': [],
    })

    # Run
    metadata._detect_relationships()

    # Assert
    detected_relationship = {
        'parent_table_name': 'parent',
        'child_table_name': 'child',
        'parent_primary_key': 'email',
        'child_foreign_key': 'email',
    }
    assert metadata.relationships == [detected_relationship]

    # Compared against a fresh literal so the expectation never aliases the loaded input.
    child_email_column = metadata.tables['child'].columns['email']
    assert child_email_column == {'sdtype': 'email', 'pii': True}
| 2665 | + |
def test__detect_relationships_semantic_foreign_key_does_not_overwrite_mismatch(self):
    """Test a child column whose sdtype differs from the parent key is left alone.

    The parent primary key ``email`` has sdtype ``email`` while the same-named child
    column is ``categorical``. The test checks that no relationship is recorded and
    that the child column keeps its ``categorical`` sdtype.
    """
    # Setup
    parent_table = {
        'primary_key': 'email',
        'columns': {
            'email': {'sdtype': 'email'},
            'user_name': {'sdtype': 'categorical'},
        },
    }
    child_table = {
        'primary_key': 'child_id',
        'columns': {
            'child_id': {'sdtype': 'id'},
            'email': {'sdtype': 'categorical'},
        },
    }
    metadata = Metadata.load_from_dict({
        'tables': {'parent': parent_table, 'child': child_table},
        'relationships': [],
    })

    # Run
    metadata._detect_relationships()

    # Assert
    assert metadata.relationships == []
    child_email_column = metadata.tables['child'].columns['email']
    assert child_email_column['sdtype'] == 'categorical'
| 2695 | + |
def test__detect_relationships_restores_foreign_key_metadata_after_failure(self):
    """Test the child foreign key metadata is unchanged when adding a relationship fails.

    ``add_relationship`` is replaced with a mock that raises ``InvalidMetadataError``.
    The test checks that it is called exactly once for the ``users``/``sessions`` pair
    and that the ``user_id`` column in ``sessions`` still equals its original metadata.
    """
    # Setup
    original_foreign_key_metadata = {'sdtype': 'email', 'pii': True}
    users_table = {
        'primary_key': 'user_id',
        'columns': {
            'user_id': {'sdtype': 'id'},
            'user_name': {'sdtype': 'categorical'},
        },
    }
    sessions_table = {
        'primary_key': 'session_id',
        'columns': {
            # A copy is loaded so the expected value cannot be mutated through the metadata.
            'user_id': original_foreign_key_metadata.copy(),
            'session_id': {'sdtype': 'id'},
        },
    }
    metadata = Metadata.load_from_dict({
        'tables': {'users': users_table, 'sessions': sessions_table},
        'relationships': [],
    })
    rejection = InvalidMetadataError('bad relationship')
    metadata.add_relationship = Mock(side_effect=rejection)

    # Run
    metadata._detect_relationships()

    # Assert
    metadata.add_relationship.assert_called_once_with('users', 'sessions', 'user_id', 'user_id')
    sessions_user_id = metadata.tables['sessions'].columns['user_id']
    assert sessions_user_id == original_foreign_key_metadata
| 2727 | + |
2627 | 2728 | def test__detect_relationships_circular(self): |
2628 | 2729 | """Test that relationships that invalidate the metadata are not added.""" |
2629 | 2730 | # Setup |
|
0 commit comments