Skip to content

Commit 509440e

Browse files
authored
Apply code quality formatting rules (#140)
1 parent c1355a7 commit 509440e

File tree

28 files changed

+331
-327
lines changed

28 files changed

+331
-327
lines changed

.bumpversion.cfg

+2-3
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@ parse = (?P<major>\d+)
88
((?P<prerelease>[a-z]+)
99
?(\.)?
1010
(?P<num>\d+))?
11-
serialize =
11+
serialize =
1212
{major}.{minor}.{patch}{prerelease}{num}
1313
{major}.{minor}.{patch}
1414

1515
[bumpversion:part:prerelease]
1616
first_value = a
17-
values =
17+
values =
1818
a
1919
b
2020
rc
@@ -33,4 +33,3 @@ replace = version = "{new_version}"
3333
[bumpversion:file:dbt/adapters/mariadb/__version__.py]
3434
search = version = "{current_version}"
3535
replace = version = "{new_version}"
36-

MANIFEST.in

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
recursive-include dbt/include *.sql *.yml *.md
1+
recursive-include dbt/include *.sql *.yml *.md

dbt/adapters/mariadb/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
Plugin = AdapterPlugin(
1212
adapter=MariaDBAdapter,
1313
credentials=MariaDBCredentials,
14-
include_path=mariadb.PACKAGE_PATH)
14+
include_path=mariadb.PACKAGE_PATH,
15+
)

dbt/adapters/mariadb/column.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from dbt.adapters.base.column import Column
55

6-
Self = TypeVar('Self', bound='MariaDBColumn')
6+
Self = TypeVar("Self", bound="MariaDBColumn")
77

88

99
@dataclass
@@ -18,7 +18,7 @@ class MariaDBColumn(Column):
1818

1919
@property
2020
def quoted(self) -> str:
21-
return '`{}`'.format(self.column)
21+
return "`{}`".format(self.column)
2222

2323
def __repr__(self) -> str:
2424
return "<MariaDBColumn {} ({})>".format(self.name, self.data_type)

dbt/adapters/mariadb/connections.py

+16-18
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,7 @@ def __init__(self, **kwargs):
3939

4040
def __post_init__(self):
4141
# Database and schema are treated as the same thing
42-
if (
43-
self.database is not None and
44-
self.database != self.schema
45-
):
42+
if self.database is not None and self.database != self.schema:
4643
raise dbt.exceptions.RuntimeException(
4744
f" schema: {self.schema} \n"
4845
f" database: {self.database} \n"
@@ -76,8 +73,8 @@ class MariaDBConnectionManager(SQLConnectionManager):
7673

7774
@classmethod
7875
def open(cls, connection):
79-
if connection.state == 'open':
80-
logger.debug('Connection is already open, skipping open.')
76+
if connection.state == "open":
77+
logger.debug("Connection is already open, skipping open.")
8178
return connection
8279

8380
credentials = cls.get_credentials(connection.credentials)
@@ -96,26 +93,29 @@ def open(cls, connection):
9693

9794
try:
9895
connection.handle = mysql.connector.connect(**kwargs)
99-
connection.state = 'open'
96+
connection.state = "open"
10097
except mysql.connector.Error:
10198

10299
try:
103-
logger.debug("Failed connection without supplying the `database`. "
104-
"Trying again with `database` included.")
100+
logger.debug(
101+
"Failed connection without supplying the `database`. "
102+
"Trying again with `database` included."
103+
)
105104

106105
# Try again with the database included
107106
kwargs["database"] = credentials.schema
108107

109108
connection.handle = mysql.connector.connect(**kwargs)
110-
connection.state = 'open'
109+
connection.state = "open"
111110
except mysql.connector.Error as e:
112111

113-
logger.debug("Got an error when attempting to open a MariaDB "
114-
"connection: '{}'"
115-
.format(e))
112+
logger.debug(
113+
"Got an error when attempting to open a MariaDB "
114+
"connection: '{}'".format(e)
115+
)
116116

117117
connection.handle = None
118-
connection.state = 'fail'
118+
connection.state = "fail"
119119

120120
raise dbt.exceptions.FailedToConnectException(str(e))
121121

@@ -134,7 +134,7 @@ def exception_handler(self, sql):
134134
yield
135135

136136
except mysql.connector.DatabaseError as e:
137-
logger.debug('MariaDB error: {}'.format(str(e)))
137+
logger.debug("MariaDB error: {}".format(str(e)))
138138

139139
try:
140140
self.rollback_if_open()
@@ -167,7 +167,5 @@ def get_response(cls, cursor) -> AdapterResponse:
167167
# There's no real way to get the status from the mysql-connector-python driver.
168168
# So just return the default value.
169169
return AdapterResponse(
170-
_message="{} {}".format(code, num_rows),
171-
rows_affected=num_rows,
172-
code=code
170+
_message="{} {}".format(code, num_rows), rows_affected=num_rows, code=code
173171
)

dbt/adapters/mariadb/impl.py

+61-58
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
logger = AdapterLogger("mysql")
2020

21-
LIST_SCHEMAS_MACRO_NAME = 'list_schemas'
22-
LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching'
21+
LIST_SCHEMAS_MACRO_NAME = "list_schemas"
22+
LIST_RELATIONS_MACRO_NAME = "list_relations_without_caching"
2323

2424

2525
class MariaDBAdapter(SQLAdapter):
@@ -29,28 +29,23 @@ class MariaDBAdapter(SQLAdapter):
2929

3030
@classmethod
3131
def date_function(cls):
32-
return 'current_date()'
32+
return "current_date()"
3333

3434
@classmethod
35-
def convert_datetime_type(
36-
cls, agate_table: agate.Table, col_idx: int
37-
) -> str:
35+
def convert_datetime_type(cls, agate_table: agate.Table, col_idx: int) -> str:
3836
return "timestamp"
3937

4038
def quote(self, identifier):
41-
return '`{}`'.format(identifier)
39+
return "`{}`".format(identifier)
4240

4341
def list_relations_without_caching(
4442
self, schema_relation: MariaDBRelation
4543
) -> List[MariaDBRelation]:
46-
kwargs = {'schema_relation': schema_relation}
44+
kwargs = {"schema_relation": schema_relation}
4745
try:
48-
results = self.execute_macro(
49-
LIST_RELATIONS_MACRO_NAME,
50-
kwargs=kwargs
51-
)
46+
results = self.execute_macro(LIST_RELATIONS_MACRO_NAME, kwargs=kwargs)
5247
except dbt.exceptions.RuntimeException as e:
53-
errmsg = getattr(e, 'msg', '')
48+
errmsg = getattr(e, "msg", "")
5449
if f"MariaDB database '{schema_relation}' not found" in errmsg:
5550
return []
5651
else:
@@ -64,13 +59,11 @@ def list_relations_without_caching(
6459
raise dbt.exceptions.RuntimeException(
6560
"Invalid value from "
6661
f'"mariadb__list_relations_without_caching({kwargs})", '
67-
f'got {len(row)} values, expected 4'
62+
f"got {len(row)} values, expected 4"
6863
)
6964
_, name, _schema, relation_type = row
7065
relation = self.Relation.create(
71-
schema=_schema,
72-
identifier=name,
73-
type=relation_type
66+
schema=_schema, identifier=name, type=relation_type
7467
)
7568
relations.append(relation)
7669

@@ -88,9 +81,9 @@ def _get_columns_for_catalog(
8881
for column in columns:
8982
# convert MariaDBColumns into catalog dicts
9083
as_dict = asdict(column)
91-
as_dict['column_name'] = as_dict.pop('column', None)
92-
as_dict['column_type'] = as_dict.pop('dtype')
93-
as_dict['table_database'] = None
84+
as_dict["column_name"] = as_dict.pop("column", None)
85+
as_dict["column_type"] = as_dict.pop("dtype")
86+
as_dict["table_database"] = None
9487
yield as_dict
9588

9689
def get_relation(
@@ -102,48 +95,58 @@ def get_relation(
10295
return super().get_relation(database, schema, identifier)
10396

10497
def parse_show_columns(
105-
self,
106-
relation: Relation,
107-
raw_rows: List[agate.Row]
98+
self, relation: Relation, raw_rows: List[agate.Row]
10899
) -> List[MariaDBColumn]:
109-
return [MariaDBColumn(
110-
table_database=None,
111-
table_schema=relation.schema,
112-
table_name=relation.name,
113-
table_type=relation.type,
114-
table_owner=None,
115-
table_stats=None,
116-
column=column.column,
117-
column_index=idx,
118-
dtype=column.dtype,
119-
) for idx, column in enumerate(raw_rows)]
100+
return [
101+
MariaDBColumn(
102+
table_database=None,
103+
table_schema=relation.schema,
104+
table_name=relation.name,
105+
table_type=relation.type,
106+
table_owner=None,
107+
table_stats=None,
108+
column=column.column,
109+
column_index=idx,
110+
dtype=column.dtype,
111+
)
112+
for idx, column in enumerate(raw_rows)
113+
]
120114

121115
def get_catalog(self, manifest):
122116
schema_map = self._get_catalog_schemas(manifest)
123117
if len(schema_map) > 1:
124118
dbt.exceptions.raise_compiler_error(
125-
f'Expected only one database in get_catalog, found '
126-
f'{list(schema_map)}'
119+
f"Expected only one database in get_catalog, found "
120+
f"{list(schema_map)}"
127121
)
128122

129123
with executor(self.config) as tpe:
130124
futures: List[Future[agate.Table]] = []
131125
for info, schemas in schema_map.items():
132126
for schema in schemas:
133-
futures.append(tpe.submit_connected(
134-
self, schema,
135-
self._get_one_catalog, info, [schema], manifest
136-
))
127+
futures.append(
128+
tpe.submit_connected(
129+
self,
130+
schema,
131+
self._get_one_catalog,
132+
info,
133+
[schema],
134+
manifest,
135+
)
136+
)
137137
catalogs, exceptions = catch_as_completed(futures)
138138
return catalogs, exceptions
139139

140140
def _get_one_catalog(
141-
self, information_schema, schemas, manifest,
141+
self,
142+
information_schema,
143+
schemas,
144+
manifest,
142145
) -> agate.Table:
143146
if len(schemas) != 1:
144147
dbt.exceptions.raise_compiler_error(
145-
f'Expected only one schema in mariadb _get_one_catalog, found '
146-
f'{schemas}'
148+
f"Expected only one schema in mariadb _get_one_catalog, found "
149+
f"{schemas}"
147150
)
148151

149152
database = information_schema.database
@@ -153,14 +156,11 @@ def _get_one_catalog(
153156
for relation in self.list_relations(database, schema):
154157
logger.debug("Getting table schema for relation {}", relation)
155158
columns.extend(self._get_columns_for_catalog(relation))
156-
return agate.Table.from_object(
157-
columns, column_types=DEFAULT_TYPE_TESTER
158-
)
159+
return agate.Table.from_object(columns, column_types=DEFAULT_TYPE_TESTER)
159160

160161
def check_schema_exists(self, database, schema):
161162
results = self.execute_macro(
162-
LIST_SCHEMAS_MACRO_NAME,
163-
kwargs={'database': database}
163+
LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database}
164164
)
165165

166166
exists = True if schema in [row[0] for row in results] else False
@@ -174,13 +174,13 @@ def update_column_sql(
174174
clause: str,
175175
where_clause: Optional[str] = None,
176176
) -> str:
177-
clause = f'update {dst_name} set {dst_column} = {clause}'
177+
clause = f"update {dst_name} set {dst_column} = {clause}"
178178
if where_clause is not None:
179-
clause += f' where {where_clause}'
179+
clause += f" where {where_clause}"
180180
return clause
181181

182182
def timestamp_add_sql(
183-
self, add_to: str, number: int = 1, interval: str = 'hour'
183+
self, add_to: str, number: int = 1, interval: str = "hour"
184184
) -> str:
185185
# for backwards compatibility, we're compelled to set some sort of
186186
# default. A lot of searching has lead me to believe that the
@@ -189,11 +189,14 @@ def timestamp_add_sql(
189189
return f"date_add({add_to}, interval {number} {interval})"
190190

191191
def string_add_sql(
192-
self, add_to: str, value: str, location='append',
192+
self,
193+
add_to: str,
194+
value: str,
195+
location="append",
193196
) -> str:
194-
if location == 'append':
197+
if location == "append":
195198
return f"concat({add_to}, '{value}')"
196-
elif location == 'prepend':
199+
elif location == "prepend":
197200
return f"concat({value}, '{add_to}')"
198201
else:
199202
raise dbt.exceptions.RuntimeException(
@@ -216,15 +219,15 @@ def get_rows_different_sql(
216219

217220
alias_a = "A"
218221
alias_b = "B"
219-
columns_csv_a = ', '.join([f"{alias_a}.{name}" for name in names])
220-
columns_csv_b = ', '.join([f"{alias_b}.{name}" for name in names])
222+
columns_csv_a = ", ".join([f"{alias_a}.{name}" for name in names])
223+
columns_csv_b = ", ".join([f"{alias_b}.{name}" for name in names])
221224
join_condition = " AND ".join(
222225
[f"{alias_a}.{name} = {alias_b}.{name}" for name in names]
223226
)
224227
first_column = names[0]
225228

226229
# There is no EXCEPT or MINUS operator, so we need to simulate it
227-
COLUMNS_EQUAL_SQL = '''
230+
COLUMNS_EQUAL_SQL = """
228231
SELECT
229232
row_count_diff.difference as row_count_difference,
230233
diff_count.num_missing as num_mismatched
@@ -259,7 +262,7 @@ def get_rows_different_sql(
259262
260263
) as missing
261264
) as diff_count ON row_count_diff.id = diff_count.id
262-
'''.strip()
265+
""".strip()
263266

264267
sql = COLUMNS_EQUAL_SQL.format(
265268
alias_a=alias_a,

dbt/adapters/mariadb/relation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class MariaDBIncludePolicy(Policy):
2222
class MariaDBRelation(BaseRelation):
2323
quote_policy: MariaDBQuotePolicy = MariaDBQuotePolicy()
2424
include_policy: MariaDBIncludePolicy = MariaDBIncludePolicy()
25-
quote_character: str = '`'
25+
quote_character: str = "`"
2626

2727
def __post_init__(self):
2828
if self.database != self.schema and self.database:

dbt/adapters/mysql/__init__.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,5 @@
99

1010

1111
Plugin = AdapterPlugin(
12-
adapter=MySQLAdapter,
13-
credentials=MySQLCredentials,
14-
include_path=mysql.PACKAGE_PATH)
12+
adapter=MySQLAdapter, credentials=MySQLCredentials, include_path=mysql.PACKAGE_PATH
13+
)

0 commit comments

Comments
 (0)