diff --git a/edx/analytics/tasks/common/spark.py b/edx/analytics/tasks/common/spark.py
index 63eb311024..a9039b6c57 100644
--- a/edx/analytics/tasks/common/spark.py
+++ b/edx/analytics/tasks/common/spark.py
@@ -197,6 +197,10 @@ class EventLogSelectionMixinSpark(EventLogSelectionDownstreamMixin):
     Extract events corresponding to a specified time interval.
     """
+    direct_eventlogs_processing = luigi.BoolParameter(
+        description='Whether or not to process event log source directly with spark',
+        default=False
+    )
 
     def __init__(self, *args, **kwargs):
         """
         Call path selection task to get list of log files matching the pattern
@@ -242,33 +246,34 @@ def get_log_schema(self):
         return event_log_schema
 
     def get_input_source(self, *args):
-        manifest_path = self.get_config_from_args('manifest_path', *args, default_value='')
-        targets = PathSelectionTaskSpark(
+        return PathSelectionByDateIntervalTask(
             source=self.source,
             interval=self.interval,
             pattern=self.pattern,
-            date_pattern=self.date_pattern,
-            manifest_id=self.manifest_id,
-            manifest_dir=manifest_path,
+            date_pattern=self.date_pattern
         ).output()
-        if len(targets) and 'manifest' in targets[0].path:
-            # Reading manifest as rdd with spark is alot faster as compared to hadoop.
-            # Currently, we're getting only 1 manifest file per request, so we will create a single rdd from it.
-            # If there are multiple manifest files, each file can be read as rdd and then union it with other manifest rdds
-            self.log.warn("PYSPARK LOGGER: Reading manifest file :: {} ".format(targets[0].path))
-            source_rdd = self._spark.sparkContext.textFile(targets[0].path)
-            broadcast_value = self._spark.sparkContext.broadcast(source_rdd.collect())
-        else:
-            self.log.warn("PYSPARK LOGGER: Reading normal targets")
-            broadcast_value = self._spark.sparkContext.broadcast([target.path for target in targets])
-        return broadcast_value
 
     def get_event_log_dataframe(self, spark, *args, **kwargs):
         from pyspark.sql.functions import to_date, udf, struct, date_format
-        dataframe = spark.read.format('json').load(
-            self.get_input_source(*args).value,
-            schema=self.get_log_schema()
-        )
+        schema = self.get_log_schema()
+        if self.direct_eventlogs_processing:
+            self.log.warn("\nPYSPARK => Processing event log source directly\n")
+            event_log_source = self.get_config_from_args('event_log_source', *args, default_value=None)
+            if event_log_source is not None:
+                event_log_source = json.loads(event_log_source)
+            self.log.warn("\nPYSPARK => Event log source : {}\n".format(event_log_source))
+            dataframe = spark.read.format('json').load(event_log_source[0], schema=self.get_log_schema())
+            source_list_count = len(event_log_source)
+            if source_list_count > 1:
+                for k in range(1, source_list_count):
+                    dataframe = dataframe.union(
+                        spark.read.format('json').load(event_log_source[k], schema=self.get_log_schema())
+                    )
+        else:
+            self.log.warn("\nPYSPARK => Processing path selection output\n")
+            input_source = self.get_input_source(*args)
+            path_targets = [target.path for target in input_source]
+            dataframe = spark.read.format('json').load(path_targets, schema=self.get_log_schema())
         dataframe = dataframe.filter(dataframe['time'].isNotNull()) \
             .withColumn('event_date', date_format(to_date(dataframe['time']), 'yyyy-MM-dd'))
         dataframe = dataframe.filter(
@@ -310,26 +315,6 @@ def conf(self):
         """
         return self._dict_config(self.spark_conf)
 
-    @property
-    def manifest_id(self):
-        params = {
-            'source': self.source,
-            'interval': self.interval,
-            'pattern': self.pattern,
-            'date_pattern': self.date_pattern,
-            'spark': 'for_some_difference_with_hadoop_manifest'
-        }
-        return str(hash(frozenset(params.items()))).replace('-', 'n')
-
-    def get_manifest_path(self, *args):
-        manifest_path = self.get_config_from_args('manifest_path', *args, default_value='')
-        return get_target_from_url(
-            url_path_join(
-                manifest_path,
-                self.manifest_id + '.manifest'
-            )
-        )
-
     def spark_job(self):
         """
         Spark code for the job
diff --git a/edx/analytics/tasks/insights/location_per_course.py b/edx/analytics/tasks/insights/location_per_course.py
index 1641ecd8d4..2ad1241335 100644
--- a/edx/analytics/tasks/insights/location_per_course.py
+++ b/edx/analytics/tasks/insights/location_per_course.py
@@ -219,7 +219,7 @@ def run(self):
     def get_luigi_configuration(self):
         options = {}
         config = luigi.configuration.get_config()
-        options['manifest_path'] = config.get('manifest', 'path', '')
+        options['event_log_source'] = config.get('event-logs', 'source', '')
         return options
 
     def spark_job(self, *args):
@@ -246,7 +246,7 @@ def spark_job(self, *args):
             WHERE rank = 1
         """
         result = self._spark.sql(query)
-        result.coalesce(4).write.partitionBy('dt').csv(self.output_dir().path, mode='append', sep='\t')
+        result.coalesce(2).write.partitionBy('dt').csv(self.output_dir().path, mode='append', sep='\t')
 
 
 class LastCountryOfUserDownstreamMixin(
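
Note on the new configuration (not part of the diff above): with direct_eventlogs_processing enabled,
get_event_log_dataframe takes the event-log locations from the [event-logs] "source" setting (read in
get_luigi_configuration in location_per_course.py) and parses the value with json.loads, so the setting
is expected to hold a JSON list of URLs or paths. A minimal Python sketch of how such a value would be
parsed; the bucket and path names are placeholders, not taken from this change:

    import json

    # Hypothetical value of the [event-logs] "source" option; paths are illustrative only.
    raw_value = '["s3://example-bucket/logs/tracking/", "hdfs://example-namenode/data/old-tracking-logs/"]'
    event_log_source = json.loads(raw_value)
    # -> ['s3://example-bucket/logs/tracking/', 'hdfs://example-namenode/data/old-tracking-logs/']

Each entry in the list becomes one spark.read.format('json').load(...) call, with the resulting
DataFrames unioned onto the first. When the flag is left at its default of False, the task instead loads
the paths returned by PathSelectionByDateIntervalTask.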