@@ -60,11 +60,33 @@ class S3ToRedshiftOperator(BaseOperator):
60
60
with. Only required if using a load_type of
61
61
"upsert".
62
62
:type incremental_key: string
63
+ :param foreign_key: *(optional)* This specifies any foreign_keys
64
+ in the table and which corresponding table
65
+ and key they reference. This may be either
66
+ a dictionary or list of dictionaries (for
67
+ multiple foreign keys). The fields that are
68
+ required in each dictionary are:
69
+ - column_name
70
+ - reftable
71
+ - ref_column
72
+ :type foreign_key: dictionary
73
+ :param distkey: *(optional)* The distribution key for the
74
+ table. Only one key may be specified.
75
+ :type distkey: string
76
+ :param sortkey: *(optional)* The sort keys for the table.
77
+ If more than one key is specified, set this
78
+ as a list.
79
+ :type sortkey: string
80
+ :param sort_type: *(optional)* The style of distribution
81
+ to sort the table. Possible values include:
82
+ - compound
83
+ - interleaved
84
+ Defaults to "compound".
85
+ :type sort_type: string
63
86
"""
64
87
65
88
template_fields = ('s3_key' ,
66
- 'origin_schema' ,
67
- 'com' )
89
+ 'origin_schema' )
68
90
69
91
@apply_defaults
70
92
def __init__ (self ,
@@ -81,7 +103,10 @@ def __init__(self,
81
103
load_type = 'append' ,
82
104
primary_key = None ,
83
105
incremental_key = None ,
84
- timeformat = 'auto' ,
106
+ foreign_key = {},
107
+ distkey = None ,
108
+ sortkey = '' ,
109
+ sort_type = 'COMPOUND' ,
85
110
* args ,
86
111
** kwargs ):
87
112
super ().__init__ (* args , ** kwargs )
@@ -98,14 +123,29 @@ def __init__(self,
98
123
self .load_type = load_type
99
124
self .primary_key = primary_key
100
125
self .incremental_key = incremental_key
101
- self .timeformat = timeformat
126
+ self .foreign_key = foreign_key
127
+ self .distkey = distkey
128
+ self .sortkey = sortkey
129
+ self .sort_type = sort_type
102
130
103
131
if self .load_type .lower () not in ["append" , "rebuild" , "upsert" ]:
104
132
raise Exception ('Please choose "append", "rebuild", or "upsert".' )
105
133
106
134
if self .schema_location .lower () not in ['s3' , 'local' ]:
107
135
raise Exception ('Valid Schema Locations are "s3" or "local".' )
108
136
137
+ if not (isinstance (self .sortkey , str ) or isinstance (self .sortkey , list )):
138
+ raise Exception ('Sort Keys must be specified as either a string or list.' )
139
+
140
+ if not (isinstance (self .foreign_key , dict ) or isinstance (self .foreign_key , list )):
141
+ raise Exception ('Foreign Keys must be specified as either a dictionary or a list of dictionaries.' )
142
+
143
+ if ((',' in self .distkey ) or not isinstance (self .distkey , str )):
144
+ raise Exception ('Only one distribution key may be specified.' )
145
+
146
+ if self .sort_type .lower () not in ('compound' , 'interleaved' ):
147
+ raise Exception ('Please choose "compound" or "interleaved" for sort type.' )
148
+
109
149
def execute (self , context ):
110
150
# Append a random string to the end of the staging table to ensure
111
151
# no conflicts if multiple processes running concurrently.
@@ -337,6 +377,8 @@ def create_if_not_exists(self, schema, pg_hook, temp=False):
337
377
for item in schema :
338
378
k = "{quote}{key}{quote}" .format (quote = '"' , key = item ['name' ])
339
379
field = ' ' .join ([k , item ['type' ]])
380
+ if isinstance (self .sortkey , str ) and self .sortkey == item ['name' ]:
381
+ field += ' sortkey'
340
382
output += field
341
383
output += ', '
342
384
# Remove last comma and space after schema items loop ends
@@ -346,12 +388,50 @@ def create_if_not_exists(self, schema, pg_hook, temp=False):
346
388
else :
347
389
copy_table = self .table
348
390
create_schema_query = \
349
- '''CREATE SCHEMA IF NOT EXISTS "{0}";''' .format (
350
- self .redshift_schema )
391
+ '''
392
+ CREATE SCHEMA IF NOT EXISTS "{0}";
393
+ ''' .format (self .redshift_schema )
394
+
395
+ pk = ''
396
+ fk = ''
397
+ dk = ''
398
+ sk = ''
399
+
400
+ if self .primary_key :
401
+ pk = ', primary key("{0}")' .format (self .primary_key )
402
+
403
+ if self .foreign_key :
404
+ if isinstance (self .foreign_key , list ):
405
+ fk = ', '
406
+ for i , e in enumerate (self .foreign_key ):
407
+ fk += 'foreign key("{0}") references {1}("{2}")' .format (e ['column_name' ],
408
+ e ['reftable' ],
409
+ e ['ref_column' ])
410
+ if i != (len (self .foreign_key ) - 1 ):
411
+ fk += ', ' ""
412
+ elif isinstance (self .foreign_key , dict ):
413
+ fk += ', '
414
+ fk += 'foreign key("{0}") references {1}("{2}")' .format (self .foreign_key ['column_name' ],
415
+ self .foreign_key ['reftable' ],
416
+ self .foreign_key ['ref_column' ])
417
+ if self .distkey :
418
+ dk = 'distkey({})' .format (self .distkey )
419
+
420
+ if self .sortkey :
421
+ if isinstance (self .sortkey , list ):
422
+ sk += '{0} sortkey({1})' .format (self .sort_type , ', ' .join (["{}" .format (e ) for e in self .sortkey ]))
423
+
351
424
create_table_query = \
352
- '''CREATE TABLE IF NOT EXISTS "{0}"."{1}" ({2})''' .format (
353
- self .redshift_schema ,
354
- copy_table ,
355
- output )
425
+ '''
426
+ CREATE TABLE IF NOT EXISTS "{schema}"."{table}"
427
+ ({fields}{primary_key}{foreign_key}) {distkey} {sortkey}
428
+ ''' .format (schema = self .redshift_schema ,
429
+ table = copy_table ,
430
+ fields = output ,
431
+ primary_key = pk ,
432
+ foreign_key = fk ,
433
+ distkey = dk ,
434
+ sortkey = sk )
435
+
356
436
pg_hook .run (create_schema_query )
357
437
pg_hook .run (create_table_query )
0 commit comments