Skip to content

Commit b1a2de5

Browse files
add truncate option separate from rebuild option
1 parent b275fd4 commit b1a2de5

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

Diff for: operators/s3_to_redshift.py

+28-12
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import random
33
import string
44
import logging
5+
6+
from airflow.utils.db import provide_session
7+
from airflow.models import Connection
58
from airflow.utils.decorators import apply_defaults
9+
610
from airflow.models import BaseOperator
711
from airflow.hooks.S3_hook import S3Hook
812
from airflow.hooks.postgres_hook import PostgresHook
9-
from airflow.utils.db import provide_session
10-
from airflow.models import Connection
1113

1214

1315
class S3ToRedshiftOperator(BaseOperator):
@@ -46,9 +48,12 @@ class S3ToRedshiftOperator(BaseOperator):
4648
possible values include "mysql".
4749
:type origin_datatype: string
4850
:param load_type: The method of loading into Redshift that
49-
should occur. Options are "append",
50-
"rebuild", and "upsert". Defaults to
51-
"append."
51+
should occur. Options:
52+
- "append"
53+
- "rebuild"
54+
- "truncate"
55+
- "upsert"
56+
Defaults to "append."
5257
:type load_type: string
5358
:param primary_key: *(optional)* The primary key for the
5459
destination table. Not enforced by redshift
@@ -128,10 +133,10 @@ def __init__(self,
128133
self.sortkey = sortkey
129134
self.sort_type = sort_type
130135

131-
if self.load_type.lower() not in ["append", "rebuild", "upsert"]:
136+
if self.load_type.lower() not in ("append", "rebuild", "truncate", "upsert"):
132137
            raise Exception('Please choose "append", "rebuild", "truncate", or "upsert".')
133138

134-
if self.schema_location.lower() not in ['s3', 'local']:
139+
if self.schema_location.lower() not in ('s3', 'local'):
135140
raise Exception('Valid Schema Locations are "s3" or "local".')
136141

137142
if not (isinstance(self.sortkey, str) or isinstance(self.sortkey, list)):
@@ -152,9 +157,12 @@ def execute(self, context):
152157
letters = string.ascii_lowercase
153158
random_string = ''.join(random.choice(letters) for _ in range(7))
154159
self.temp_suffix = '_tmp_{0}'.format(random_string)
160+
155161
if self.origin_schema:
156162
schema = self.read_and_format()
163+
157164
pg_hook = PostgresHook(self.redshift_conn_id)
165+
158166
self.create_if_not_exists(schema, pg_hook)
159167
self.reconcile_schemas(schema, pg_hook)
160168
self.copy_data(pg_hook, schema)
@@ -221,7 +229,6 @@ def read_and_format(self):
221229
if i['type'] == e['avro']:
222230
i['type'] = e['redshift']
223231

224-
print(schema)
225232
return schema
226233

227234
def reconcile_schemas(self, schema, pg_hook):
@@ -277,7 +284,7 @@ def getS3Conn():
277284
elif aws_role_arn:
278285
creds = ("aws_iam_role={0}"
279286
.format(aws_role_arn))
280-
287+
281288
return creds
282289

283290
# Delete records from the destination table where the incremental_key
@@ -331,6 +338,11 @@ def getS3Conn():
331338
FILLTARGET
332339
'''.format(self.redshift_schema, self.table, self.temp_suffix)
333340

341+
drop_sql = \
342+
'''
343+
DROP TABLE IF EXISTS "{0}"."{1}"
344+
'''.format(self.redshift_schema, self.table)
345+
334346
drop_temp_sql = \
335347
'''
336348
DROP TABLE IF EXISTS "{0}"."{1}{2}"
@@ -366,6 +378,13 @@ def getS3Conn():
366378
base_sql)
367379
if self.load_type == 'append':
368380
pg_hook.run(load_sql)
381+
elif self.load_type == 'rebuild':
382+
pg_hook.run(drop_sql)
383+
self.create_if_not_exists(schema, pg_hook)
384+
pg_hook.run(load_sql)
385+
elif self.load_type == 'truncate':
386+
pg_hook.run(truncate_sql)
387+
pg_hook.run(load_sql)
369388
elif self.load_type == 'upsert':
370389
self.create_if_not_exists(schema, pg_hook, temp=True)
371390
load_temp_sql = \
@@ -378,9 +397,6 @@ def getS3Conn():
378397
pg_hook.run(delete_confirm_sql)
379398
pg_hook.run(append_sql, autocommit=True)
380399
pg_hook.run(drop_temp_sql)
381-
elif self.load_type == 'rebuild':
382-
pg_hook.run(truncate_sql)
383-
pg_hook.run(load_sql)
384400

385401
def create_if_not_exists(self, schema, pg_hook, temp=False):
386402
output = ''

0 commit comments

Comments
 (0)