Skip to content

Commit c53a9b7

Browse files
authored
Update constraint names generation
Fixes databricks#24
1 parent d90635a commit c53a9b7

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

pgsqlite/pgsqlite.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,9 @@ def __init__(self, sqlite_filename: str, pg_conninfo: str, show_sample_data: boo
181181
self.max_import_concurrency = max_import_concurrency
182182
db = Database(self.sqlite_filename)
183183
self._tables = {t.name: ParsedTable(t) for t in db.tables}
184-
184+
self._constr_names = []
185+
self._constr_names_counter = 0
186+
185187
@property
186188
def tables(self):
187189
return self._tables.values()
@@ -237,6 +239,11 @@ def get_table_sql(self, table: ParsedTable) -> SQL:
237239
transpiled_pks_to_add = [table.get_transpiled_colname(pk) for pk in pks_to_add]
238240
all_column_sql = all_column_sql + SQL(",\n")
239241
pk_name = f"PK_{table.source_name}_" + ''.join(pks_to_add)
242+
if pk_name in self._constr_names:
243+
self._constr_names_counter += 1
244+
pk_name = f"{pk_name}_{self._constr_names_counter}"
245+
else:
246+
self._constr_names.append(pk_name)
240247
pk_sql = SQL(" CONSTRAINT {pk_name} PRIMARY KEY ({pks})").format(
241248
table_name=Identifier(table.transpiled_name),
242249
pk_name=Identifier(pk_name), pks=SQL(", ").join(
@@ -269,6 +276,11 @@ def get_fk_sql(self, table: ParsedTable) -> SQL:
269276
# create the foreign keys after the tables to avoid having to figure out the dep graph
270277
for fk in table.src_table.foreign_keys:
271278
fk_name = f"FK_{fk.other_table}_{fk.other_column}"
279+
if fk_name in self._constr_names:
280+
self._constr_names_counter += 1
281+
fk_name = f"{fk_name}_{self._constr_names_counter}"
282+
else:
283+
self._constr_names.append(fk_name)
272284
fk_sql = SQL("ALTER TABLE {table_name} ADD CONSTRAINT {key_name} FOREIGN KEY ({column}) REFERENCES {other_table} ({other_column})").format(
273285
table_name=Identifier(table.transpiled_name),
274286
column=Identifier(table.get_transpiled_colname(fk.column)),
@@ -586,4 +598,4 @@ async def create_all_indexes():
586598
logger.debug(json.dumps(loader.get_summary(), indent=2))
587599

588600
if args.drop_tables_after_import:
589-
loader._drop_tables()
601+
loader._drop_tables()

0 commit comments

Comments
 (0)