Skip to content

Commit 4c15afc

Browse files
committed
Support USING COLUMNS syntax
1 parent 740e39a commit 4c15afc

File tree

6 files changed

+164
-22
lines changed

6 files changed

+164
-22
lines changed

dbt/adapters/databricks/relation_configs/column_mask.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import asdict
22
from typing import ClassVar, Optional
33

4+
from dbt_common.exceptions import DbtRuntimeError
5+
46
from dbt.adapters.contracts.relation import RelationConfig
57
from dbt.adapters.databricks.relation_configs.base import (
68
DatabricksComponentConfig,
@@ -10,11 +12,27 @@
1012

1113

1214
class ColumnMaskConfig(DatabricksComponentConfig):
13-
# column name -> mask
14-
set_column_masks: dict[str, str]
15+
# column name -> mask config (function name and optional using_columns)
16+
set_column_masks: dict[str, dict[str, str]]
1517
unset_column_masks: list[str] = []
1618

1719
def get_diff(self, other: "ColumnMaskConfig") -> Optional["ColumnMaskConfig"]:
20+
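# Map each mask function name in this config to its using_columns value (None when unset).
# NOTE(review): keyed by function name only, so if two columns use the same function with
# different using_columns, the last one seen wins.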
21+
function_using_columns = {}
22+
for mask in self.set_column_masks.values():
23+
function_using_columns[mask["function"]] = mask.get("using_columns")
24+
25+
# Check if any function's using_columns has changed
26+
for mask in other.set_column_masks.values():
27+
function = mask["function"]
28+
if function in function_using_columns and function_using_columns[function] != mask.get(
29+
"using_columns"
30+
):
31+
raise DbtRuntimeError(
32+
f"The value of using_columns for existing function {function} was updated. "
33+
f"This is not supported. Please use a new function with a different name."
34+
)
35+
1836
# Find column masks that need to be unset
1937
unset_column_mask = [
2038
col for col in other.set_column_masks if col not in self.set_column_masks
@@ -45,7 +63,11 @@ def from_relation_results(cls, results: RelationResults) -> ColumnMaskConfig:
4563

4664
if column_masks:
4765
for row in column_masks.rows:
48-
set_column_masks[row[0]] = row[1]
66+
# row contains [column_name, mask_name, using_columns]
67+
mask_config = {"function": row[1]}
68+
if row[2]:
69+
mask_config["using_columns"] = row[2]
70+
set_column_masks[row[0]] = mask_config
4971

5072
return ColumnMaskConfig(set_column_masks=set_column_masks)
5173

@@ -61,6 +83,17 @@ def from_relation_config(cls, relation_config: RelationConfig) -> ColumnMaskConf
6183
set_column_masks = {}
6284
for col in columns:
6385
extra = col.get("_extra", {})
64-
if extra and "column_mask" in extra:
65-
set_column_masks[col["name"]] = extra["column_mask"]
86+
column_mask = extra.get("column_mask") if extra else None
87+
if column_mask:
88+
fully_qualified_function_name = (
89+
column_mask["function"]
90+
if "." in column_mask["function"]
91+
else (
92+
f"{relation_config.database}."
93+
f"{relation_config.schema}."
94+
f"{column_mask['function']}"
95+
)
96+
)
97+
column_mask["function"] = fully_qualified_function_name
98+
set_column_masks[col["name"]] = column_mask
6699
return ColumnMaskConfig(set_column_masks=set_column_masks)

dbt/include/databricks/macros/relations/components/column_mask.sql

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
{% macro fetch_column_masks_sql(relation) -%}
1212
SELECT
1313
column_name,
14-
mask_name
14+
mask_name,
15+
using_columns
1516
FROM `{{ relation.database|lower }}`.information_schema.column_masks
1617
WHERE table_catalog = '{{ relation.database|lower }}'
1718
AND table_schema = '{{ relation.schema|lower }}'
@@ -63,7 +64,10 @@
6364
{% macro alter_set_column_mask(relation, column, mask) -%}
6465
ALTER TABLE {{ relation.render() }}
6566
ALTER COLUMN {{ column }}
66-
SET MASK {{ mask }};
67+
SET MASK {{ mask.function }}
68+
{%- if mask.using_columns %}
69+
USING COLUMNS ({{ mask.using_columns }})
70+
{%- endif %};
6771
{%- endmacro -%}
6872

6973
{% macro column_mask_exists() %}

tests/functional/adapter/column_masks/fixtures.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
{{ config(
33
materialized = 'table'
44
) }}
5-
SELECT 'abc-123' as id, 'password123' as password;
5+
SELECT 'abc123' as id, 'password123' as password;
66
"""
77

88

@@ -14,6 +14,21 @@
1414
- name: id
1515
data_type: string
1616
- name: password
17-
column_mask: password_mask
17+
column_mask:
18+
function: password_mask
1819
data_type: string
1920
"""
21+
22+
model_with_extra_args = """
23+
version: 2
24+
models:
25+
- name: base_model
26+
columns:
27+
- name: id
28+
data_type: string
29+
- name: password
30+
data_type: string
31+
column_mask:
32+
function: weird_mask
33+
using_columns: "id, 'literal_string', 333, true, null, INTERVAL 2 DAYS"
34+
"""

tests/functional/adapter/column_masks/test_column_mask.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import pytest
22

3-
from dbt.tests.util import run_dbt
3+
from dbt.tests.util import run_dbt, write_file
44
from tests.functional.adapter.column_masks.fixtures import (
55
base_model_sql,
66
model,
7+
model_with_extra_args,
78
)
89
from tests.functional.adapter.fixtures import MaterializationV2Mixin
910

@@ -17,7 +18,7 @@ def models(self):
1718
"schema.yml": model,
1819
}
1920

20-
def test_column_mask(self, project):
21+
def test_column_mask_no_extra_args(self, project):
2122
# Create the mask function
2223
project.run_sql(
2324
f"CREATE OR REPLACE FUNCTION {project.database}.{project.test_schema}."
@@ -31,6 +32,7 @@ def test_column_mask(self, project):
3132
f"""
3233
SELECT column_name, mask_name
3334
FROM {project.database}.information_schema.column_masks
35+
WHERE table_schema = '{project.test_schema}'
3436
""",
3537
fetch="all",
3638
)
@@ -41,9 +43,45 @@ def test_column_mask(self, project):
4143

4244
# Verify masked value
4345
result = project.run_sql("SELECT id, password FROM base_model", fetch="one")
44-
assert result[0] == "abc-123"
46+
assert result[0] == "abc123"
4547
assert result[1] == "*****" # Masked value should be 5 asterisks
4648

49+
def test_column_mask_with_extra_args(self, project):
50+
write_file(model_with_extra_args, "models", "schema.yml")
51+
52+
# Create a mask function that concatenates all possible arg types: original column, other
53+
# columns, and every supported literal type from
54+
# https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-syntax-ddl-column-mask
55+
project.run_sql(
56+
f"""
57+
CREATE OR REPLACE FUNCTION {project.database}.{project.test_schema}.weird_mask(
58+
password STRING,
59+
id STRING,
60+
literal STRING,
61+
num INT,
62+
bool_val BOOLEAN,
63+
null_val STRING,
64+
interval INTERVAL DAY
65+
)
66+
RETURNS STRING
67+
RETURN CONCAT(
68+
password, '-',
69+
id, '-',
70+
literal, '-',
71+
CAST(num AS STRING), '-',
72+
CAST(bool_val AS STRING), '-',
73+
COALESCE(null_val, 'NULL'), '-',
74+
CAST(interval AS STRING)
75+
);
76+
"""
77+
)
78+
run_dbt(["run"])
79+
80+
# Not meant to resemble a real-life example; it only exercises the different argument types.
81+
result = project.run_sql("SELECT id, password FROM base_model", fetch="one")
82+
assert result[0] == "abc123"
83+
assert result[1] == "password123-abc123-literal_string-333-true-NULL-INTERVAL '2' DAY"
84+
4785

4886
@pytest.mark.skip_profile("databricks_cluster")
4987
class TestIncrementalColumnMask(TestColumnMask):

tests/functional/adapter/incremental/fixtures.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -891,21 +891,23 @@ def model(dbt, spark):
891891
'password123' as password
892892
"""
893893

894-
column_mask_name = """
894+
column_mask_base = """
895895
version: 2
896896
897897
models:
898898
- name: column_mask_sql
899899
columns:
900900
- name: id
901901
- name: name
902-
column_mask: full_mask
902+
column_mask:
903+
function: full_mask
903904
- name: email
904-
column_mask: full_mask
905+
column_mask:
906+
function: full_mask
905907
- name: password
906908
"""
907909

908-
column_mask_password = """
910+
column_mask_valid_mask_updates = """
909911
version: 2
910912
911913
models:
@@ -914,7 +916,25 @@ def model(dbt, spark):
914916
- name: id
915917
- name: name
916918
- name: email
917-
column_mask: email_mask
919+
column_mask:
920+
function: email_mask
918921
- name: password
919-
column_mask: full_mask
922+
column_mask:
923+
function: full_mask
924+
"""
925+
926+
column_mask_invalid_update = """
927+
version: 2
928+
929+
models:
930+
- name: column_mask_sql
931+
columns:
932+
- name: id
933+
- name: name
934+
column_mask:
935+
- name: email
936+
- name: password
937+
column_mask:
938+
function: full_mask
939+
using_columns: "id"
920940
"""

tests/functional/adapter/incremental/test_incremental_column_masks.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ class TestIncrementalColumnMasks(MaterializationV2Mixin):
1111
def models(self):
1212
return {
1313
"column_mask_sql.sql": fixtures.column_mask_sql,
14-
"schema.yml": fixtures.column_mask_name,
14+
"schema.yml": fixtures.column_mask_base,
1515
}
1616

1717
def test_changing_column_masks(self, project):
18-
# Create the mask functions
18+
# Create the full mask function
1919
project.run_sql(
2020
f"""
2121
CREATE OR REPLACE FUNCTION
22-
{project.database}.{project.test_schema}.full_mask(value STRING)
22+
{project.database}.{project.test_schema}.full_mask(password STRING)
2323
RETURNS STRING
2424
RETURN '*****';
2525
"""
@@ -50,7 +50,7 @@ def test_changing_column_masks(self, project):
5050
assert masks[0][3] == "password123" # password (unmasked)
5151

5252
# Update masks and verify changes
53-
util.write_file(fixtures.column_mask_password, "models", "schema.yml")
53+
util.write_file(fixtures.column_mask_valid_mask_updates, "models", "schema.yml")
5454
util.run_dbt(["run"])
5555

5656
result = project.run_sql(
@@ -61,3 +61,35 @@ def test_changing_column_masks(self, project):
6161
assert result[0][1] == "hello" # name (unmasked)
6262
assert result[0][2] == "********@example.com" # email (partially masked)
6363
assert result[0][3] == "*****" # password (masked)
64+
65+
66+
class TestInvalidColumnMaskUpdate(MaterializationV2Mixin):
67+
@pytest.fixture(scope="class")
68+
def models(self):
69+
return {
70+
"column_mask_sql.sql": fixtures.column_mask_sql,
71+
"schema.yml": fixtures.column_mask_base,
72+
}
73+
74+
def test_invalid_using_columns_update(self, project):
75+
# Create the full mask function
76+
project.run_sql(
77+
f"""
78+
CREATE OR REPLACE FUNCTION
79+
{project.database}.{project.test_schema}.full_mask(password STRING)
80+
RETURNS STRING
81+
RETURN '*****';
82+
"""
83+
)
84+
85+
# First run: name and email are masked with full_mask and no using_columns
86+
util.run_dbt(["run"])
87+
88+
# Changing using_columns for a mask function that is already applied should fail
89+
util.write_file(fixtures.column_mask_invalid_update, "models", "schema.yml")
90+
results = util.run_dbt(["run"], expect_pass=False)
91+
assert len(results.results) == 1
92+
assert (
93+
"This is not supported. Please use a new function with a different name."
94+
in results.results[0].message
95+
)

0 commit comments

Comments
 (0)