18
18
19
19
logger = AdapterLogger ("mysql" )
20
20
21
- LIST_SCHEMAS_MACRO_NAME = ' list_schemas'
22
- LIST_RELATIONS_MACRO_NAME = ' list_relations_without_caching'
21
+ LIST_SCHEMAS_MACRO_NAME = " list_schemas"
22
+ LIST_RELATIONS_MACRO_NAME = " list_relations_without_caching"
23
23
24
24
25
25
class MariaDBAdapter (SQLAdapter ):
@@ -29,28 +29,23 @@ class MariaDBAdapter(SQLAdapter):
29
29
30
30
@classmethod
31
31
def date_function (cls ):
32
- return ' current_date()'
32
+ return " current_date()"
33
33
34
34
@classmethod
35
- def convert_datetime_type (
36
- cls , agate_table : agate .Table , col_idx : int
37
- ) -> str :
35
+ def convert_datetime_type (cls , agate_table : agate .Table , col_idx : int ) -> str :
38
36
return "timestamp"
39
37
40
38
def quote (self , identifier ):
41
- return ' `{}`' .format (identifier )
39
+ return " `{}`" .format (identifier )
42
40
43
41
def list_relations_without_caching (
44
42
self , schema_relation : MariaDBRelation
45
43
) -> List [MariaDBRelation ]:
46
- kwargs = {' schema_relation' : schema_relation }
44
+ kwargs = {" schema_relation" : schema_relation }
47
45
try :
48
- results = self .execute_macro (
49
- LIST_RELATIONS_MACRO_NAME ,
50
- kwargs = kwargs
51
- )
46
+ results = self .execute_macro (LIST_RELATIONS_MACRO_NAME , kwargs = kwargs )
52
47
except dbt .exceptions .RuntimeException as e :
53
- errmsg = getattr (e , ' msg' , '' )
48
+ errmsg = getattr (e , " msg" , "" )
54
49
if f"MariaDB database '{ schema_relation } ' not found" in errmsg :
55
50
return []
56
51
else :
@@ -64,13 +59,11 @@ def list_relations_without_caching(
64
59
raise dbt .exceptions .RuntimeException (
65
60
"Invalid value from "
66
61
f'"mariadb__list_relations_without_caching({ kwargs } )", '
67
- f' got { len (row )} values, expected 4'
62
+ f" got { len (row )} values, expected 4"
68
63
)
69
64
_ , name , _schema , relation_type = row
70
65
relation = self .Relation .create (
71
- schema = _schema ,
72
- identifier = name ,
73
- type = relation_type
66
+ schema = _schema , identifier = name , type = relation_type
74
67
)
75
68
relations .append (relation )
76
69
@@ -88,9 +81,9 @@ def _get_columns_for_catalog(
88
81
for column in columns :
89
82
# convert MariaDBColumns into catalog dicts
90
83
as_dict = asdict (column )
91
- as_dict [' column_name' ] = as_dict .pop (' column' , None )
92
- as_dict [' column_type' ] = as_dict .pop (' dtype' )
93
- as_dict [' table_database' ] = None
84
+ as_dict [" column_name" ] = as_dict .pop (" column" , None )
85
+ as_dict [" column_type" ] = as_dict .pop (" dtype" )
86
+ as_dict [" table_database" ] = None
94
87
yield as_dict
95
88
96
89
def get_relation (
@@ -102,48 +95,58 @@ def get_relation(
102
95
return super ().get_relation (database , schema , identifier )
103
96
104
97
def parse_show_columns (
105
- self ,
106
- relation : Relation ,
107
- raw_rows : List [agate .Row ]
98
+ self , relation : Relation , raw_rows : List [agate .Row ]
108
99
) -> List [MariaDBColumn ]:
109
- return [MariaDBColumn (
110
- table_database = None ,
111
- table_schema = relation .schema ,
112
- table_name = relation .name ,
113
- table_type = relation .type ,
114
- table_owner = None ,
115
- table_stats = None ,
116
- column = column .column ,
117
- column_index = idx ,
118
- dtype = column .dtype ,
119
- ) for idx , column in enumerate (raw_rows )]
100
+ return [
101
+ MariaDBColumn (
102
+ table_database = None ,
103
+ table_schema = relation .schema ,
104
+ table_name = relation .name ,
105
+ table_type = relation .type ,
106
+ table_owner = None ,
107
+ table_stats = None ,
108
+ column = column .column ,
109
+ column_index = idx ,
110
+ dtype = column .dtype ,
111
+ )
112
+ for idx , column in enumerate (raw_rows )
113
+ ]
120
114
121
115
def get_catalog (self , manifest ):
122
116
schema_map = self ._get_catalog_schemas (manifest )
123
117
if len (schema_map ) > 1 :
124
118
dbt .exceptions .raise_compiler_error (
125
- f' Expected only one database in get_catalog, found '
126
- f' { list (schema_map )} '
119
+ f" Expected only one database in get_catalog, found "
120
+ f" { list (schema_map )} "
127
121
)
128
122
129
123
with executor (self .config ) as tpe :
130
124
futures : List [Future [agate .Table ]] = []
131
125
for info , schemas in schema_map .items ():
132
126
for schema in schemas :
133
- futures .append (tpe .submit_connected (
134
- self , schema ,
135
- self ._get_one_catalog , info , [schema ], manifest
136
- ))
127
+ futures .append (
128
+ tpe .submit_connected (
129
+ self ,
130
+ schema ,
131
+ self ._get_one_catalog ,
132
+ info ,
133
+ [schema ],
134
+ manifest ,
135
+ )
136
+ )
137
137
catalogs , exceptions = catch_as_completed (futures )
138
138
return catalogs , exceptions
139
139
140
140
def _get_one_catalog (
141
- self , information_schema , schemas , manifest ,
141
+ self ,
142
+ information_schema ,
143
+ schemas ,
144
+ manifest ,
142
145
) -> agate .Table :
143
146
if len (schemas ) != 1 :
144
147
dbt .exceptions .raise_compiler_error (
145
- f' Expected only one schema in mariadb _get_one_catalog, found '
146
- f' { schemas } '
148
+ f" Expected only one schema in mariadb _get_one_catalog, found "
149
+ f" { schemas } "
147
150
)
148
151
149
152
database = information_schema .database
@@ -153,14 +156,11 @@ def _get_one_catalog(
153
156
for relation in self .list_relations (database , schema ):
154
157
logger .debug ("Getting table schema for relation {}" , relation )
155
158
columns .extend (self ._get_columns_for_catalog (relation ))
156
- return agate .Table .from_object (
157
- columns , column_types = DEFAULT_TYPE_TESTER
158
- )
159
+ return agate .Table .from_object (columns , column_types = DEFAULT_TYPE_TESTER )
159
160
160
161
def check_schema_exists (self , database , schema ):
161
162
results = self .execute_macro (
162
- LIST_SCHEMAS_MACRO_NAME ,
163
- kwargs = {'database' : database }
163
+ LIST_SCHEMAS_MACRO_NAME , kwargs = {"database" : database }
164
164
)
165
165
166
166
exists = True if schema in [row [0 ] for row in results ] else False
@@ -174,13 +174,13 @@ def update_column_sql(
174
174
clause : str ,
175
175
where_clause : Optional [str ] = None ,
176
176
) -> str :
177
- clause = f' update { dst_name } set { dst_column } = { clause } '
177
+ clause = f" update { dst_name } set { dst_column } = { clause } "
178
178
if where_clause is not None :
179
- clause += f' where { where_clause } '
179
+ clause += f" where { where_clause } "
180
180
return clause
181
181
182
182
def timestamp_add_sql (
183
- self , add_to : str , number : int = 1 , interval : str = ' hour'
183
+ self , add_to : str , number : int = 1 , interval : str = " hour"
184
184
) -> str :
185
185
# for backwards compatibility, we're compelled to set some sort of
186
186
# default. A lot of searching has lead me to believe that the
@@ -189,11 +189,14 @@ def timestamp_add_sql(
189
189
return f"date_add({ add_to } , interval { number } { interval } )"
190
190
191
191
def string_add_sql (
192
- self , add_to : str , value : str , location = 'append' ,
192
+ self ,
193
+ add_to : str ,
194
+ value : str ,
195
+ location = "append" ,
193
196
) -> str :
194
- if location == ' append' :
197
+ if location == " append" :
195
198
return f"concat({ add_to } , '{ value } ')"
196
- elif location == ' prepend' :
199
+ elif location == " prepend" :
197
200
return f"concat({ value } , '{ add_to } ')"
198
201
else :
199
202
raise dbt .exceptions .RuntimeException (
@@ -216,15 +219,15 @@ def get_rows_different_sql(
216
219
217
220
alias_a = "A"
218
221
alias_b = "B"
219
- columns_csv_a = ', ' .join ([f"{ alias_a } .{ name } " for name in names ])
220
- columns_csv_b = ', ' .join ([f"{ alias_b } .{ name } " for name in names ])
222
+ columns_csv_a = ", " .join ([f"{ alias_a } .{ name } " for name in names ])
223
+ columns_csv_b = ", " .join ([f"{ alias_b } .{ name } " for name in names ])
221
224
join_condition = " AND " .join (
222
225
[f"{ alias_a } .{ name } = { alias_b } .{ name } " for name in names ]
223
226
)
224
227
first_column = names [0 ]
225
228
226
229
# There is no EXCEPT or MINUS operator, so we need to simulate it
227
- COLUMNS_EQUAL_SQL = '''
230
+ COLUMNS_EQUAL_SQL = """
228
231
SELECT
229
232
row_count_diff.difference as row_count_difference,
230
233
diff_count.num_missing as num_mismatched
@@ -259,7 +262,7 @@ def get_rows_different_sql(
259
262
260
263
) as missing
261
264
) as diff_count ON row_count_diff.id = diff_count.id
262
- ''' .strip ()
265
+ """ .strip ()
263
266
264
267
sql = COLUMNS_EQUAL_SQL .format (
265
268
alias_a = alias_a ,
0 commit comments