This repository was archived by the owner on May 1, 2024. It is now read-only.

Commit ceda01d

with rdd
1 parent 2b7b94b commit ceda01d

3 files changed, +99 -9 lines changed

edx/analytics/tasks/common/spark.py

+40
@@ -14,6 +14,7 @@
     ManifestInputTargetMixin, convert_to_manifest_input_if_necessary, remove_manifest_target_if_exists
 )
 from edx.analytics.tasks.util.overwrite import OverwriteOutputMixin
+from edx.analytics.tasks.util.spark_util import load_and_filter
 from edx.analytics.tasks.util.url import UncheckedExternalURL, get_target_from_url, url_path_join

 _file_path_to_package_meta_path = {}
@@ -201,6 +202,14 @@ class EventLogSelectionMixinSpark(EventLogSelectionDownstreamMixin):
         description='Whether or not to process event log source directly with spark',
         default=False
     )
+    cache_rdd = luigi.BoolParameter(
+        description="Whether to cache rdd or not",
+        default=False
+    )
+    rdd_checkpoint_directory = luigi.Parameter(
+        description="Path to directory where rdd can be checkpointed",
+        default=None
+    )

     def __init__(self, *args, **kwargs):
         """
@@ -275,6 +284,37 @@ def get_event_log_dataframe(self, spark, *args, **kwargs):
         )
         return dataframe

+    def get_user_location_schema(self):
+        """Return the schema matching the tuples produced by load_and_filter."""
+        from pyspark.sql.types import StructType, StringType, IntegerType
+        schema = StructType().add("user_id", IntegerType(), True) \
+            .add("course_id", StringType(), True) \
+            .add("ip", StringType(), True) \
+            .add("timestamp", StringType(), True) \
+            .add("event_date", StringType(), True)
+        return schema
+
+    def get_dataframe(self, spark, *args, **kwargs):
+        """Build the user-location dataframe from the filtered event-log RDDs."""
+        from pyspark.sql.functions import to_date, udf, struct, date_format
+        input_source = self.get_input_source(*args)
+        user_location_schema = self.get_user_location_schema()
+        master_rdd = spark.sparkContext.union(
+            # filter out unwanted data as much as possible within each rdd before union
+            map(
+                lambda target: load_and_filter(spark, target.path, self.lower_bound_date_string,
+                                               self.upper_bound_date_string),
+                input_source
+            )
+        )
+        if self.rdd_checkpoint_directory:
+            # set checkpoint location before checkpointing
+            spark.sparkContext.setCheckpointDir(self.rdd_checkpoint_directory)
+            master_rdd.localCheckpoint()
+        if self.cache_rdd:
+            master_rdd.cache()
+        dataframe = spark.createDataFrame(master_rdd, schema=user_location_schema)
+        if 'user_id' not in dataframe.columns:  # rename columns if they weren't named properly by createDataFrame
+            dataframe = dataframe.toDF('user_id', 'course_id', 'ip', 'timestamp', 'event_date')
+        return dataframe
+

 class SparkJobTask(SparkMixin, OverwriteOutputMixin, EventLogSelectionDownstreamMixin, PySparkTask):
     """

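For reference, here is a minimal local sketch of the union, checkpoint/cache, and createDataFrame flow that get_dataframe follows. It is not part of the commit; the SparkSession setup, sample rows, and /tmp path are assumptions. Note that RDD.localCheckpoint() writes checkpoint blocks to executor-local storage, so the directory passed to setCheckpointDir() only comes into play if RDD.checkpoint() is used instead.

# Standalone sketch, not part of the commit: sample rows, path, and app name are assumptions.
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StringType, IntegerType

spark = SparkSession.builder.master("local[*]").appName("user-location-sketch").getOrCreate()

schema = StructType() \
    .add("user_id", IntegerType(), True) \
    .add("course_id", StringType(), True) \
    .add("ip", StringType(), True) \
    .add("timestamp", StringType(), True) \
    .add("event_date", StringType(), True)

# Two small RDDs stand in for the per-file RDDs that load_and_filter would return.
rdd_a = spark.sparkContext.parallelize(
    [(1, "course-v1:edX+DemoX+Demo_Course", "10.0.0.1", "2017-01-01T10:00:00", "2017-01-01")])
rdd_b = spark.sparkContext.parallelize(
    [(2, "course-v1:edX+DemoX+Demo_Course", "10.0.0.2", "2017-01-02T11:00:00", "2017-01-02")])
master_rdd = spark.sparkContext.union([rdd_a, rdd_b])

cache_rdd = True                   # mirrors the cache_rdd parameter
rdd_checkpoint_directory = None    # mirrors the rdd_checkpoint_directory parameter

if rdd_checkpoint_directory:
    spark.sparkContext.setCheckpointDir(rdd_checkpoint_directory)
    master_rdd.localCheckpoint()   # truncate lineage after the union
if cache_rdd:
    master_rdd.cache()             # keep the unioned RDD in memory for reuse

dataframe = spark.createDataFrame(master_rdd, schema=schema)
dataframe.show()
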
edx/analytics/tasks/insights/location_per_course.py

+7 -9
@@ -217,25 +217,23 @@ def run(self):
         super(LastDailyIpAddressOfUserTaskSpark, self).run()

     def spark_job(self, *args):
-        from edx.analytics.tasks.util.spark_util import get_event_predicate_labels, get_course_id, get_event_time_string
+        from edx.analytics.tasks.util.spark_util import validate_course_id
         from pyspark.sql.functions import udf
         from pyspark.sql.window import Window
         from pyspark.sql.types import StringType
-        df = self.get_event_log_dataframe(self._spark, *args)
-        get_event_time = udf(get_event_time_string, StringType())
-        get_courseid = udf(get_course_id, StringType())
-        df = df.withColumn('course_id', get_courseid(df['context'])) \
-            .withColumn('timestamp', get_event_time(df['time']))
+        df = self.get_dataframe(self._spark, *args)
+        validate_courseid = udf(validate_course_id, StringType())
+        df = df.withColumn('course_id', validate_courseid(df['course_id']))
         df.createOrReplaceTempView('location')
         query = """
             SELECT
                 timestamp, ip, user_id, course_id, dt
             FROM (
                 SELECT
-                    event_date as dt, context.user_id as user_id, course_id, timestamp, ip,
-                    ROW_NUMBER() over ( PARTITION BY event_date, context.user_id, course_id ORDER BY timestamp desc) as rank
+                    event_date as dt, user_id, course_id, timestamp, ip,
+                    ROW_NUMBER() over ( PARTITION BY event_date, user_id, course_id ORDER BY timestamp desc) as rank
                 FROM location
-                WHERE ip <> '' AND timestamp <> '' AND context.user_id <> ''
+                WHERE ip <> ''
             ) user_location
             WHERE rank = 1
             """

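For reference, a self-contained sketch of the deduplication the query above performs: within each (event_date, user_id, course_id) partition, ROW_NUMBER() ordered by timestamp descending keeps only the most recent event. It is not part of the commit, and the sample rows are assumptions.

# Standalone sketch, not part of the commit: the sample rows are assumptions.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("last-ip-sketch").getOrCreate()

rows = [
    (1, "course-v1:edX+DemoX+Demo_Course", "10.0.0.1", "2017-01-01T10:00:00", "2017-01-01"),
    (1, "course-v1:edX+DemoX+Demo_Course", "10.0.0.9", "2017-01-01T23:59:00", "2017-01-01"),
    (2, "course-v1:edX+DemoX+Demo_Course", "10.0.0.2", "2017-01-01T08:00:00", "2017-01-01"),
]
df = spark.createDataFrame(rows, ["user_id", "course_id", "ip", "timestamp", "event_date"])
df.createOrReplaceTempView("location")

last_ip = spark.sql("""
    SELECT timestamp, ip, user_id, course_id, dt
    FROM (
        SELECT event_date AS dt, user_id, course_id, timestamp, ip,
               ROW_NUMBER() OVER (PARTITION BY event_date, user_id, course_id
                                  ORDER BY timestamp DESC) AS rank
        FROM location
        WHERE ip <> ''
    ) user_location
    WHERE rank = 1
""")
last_ip.show()  # one row per (user, course, day): 10.0.0.9 for user 1, 10.0.0.2 for user 2
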
edx/analytics/tasks/util/spark_util.py

+52
@@ -1,7 +1,11 @@
 """Support for spark tasks"""
+import json
+import re
+
 import edx.analytics.tasks.util.opaque_key_util as opaque_key_util
 from edx.analytics.tasks.util.constants import PredicateLabels

+PATTERN_JSON = re.compile(r'^.*?(\{.*\})\s*$')

 def get_event_predicate_labels(event_type, event_source):
     """
@@ -53,6 +57,54 @@ def get_event_time_string(event_time):
     return ''


+def filter_event_logs(row, lower_bound_date_string, upper_bound_date_string):
+    """Map a parsed event to a (user_id, course_id, ip, time, date) tuple, or () if it should be dropped."""
+    if row is None:
+        return ()
+    context = row.get('context', '')
+    raw_time = row.get('time', '')
+    if not context or not raw_time:
+        return ()
+    course_id = context.get('course_id', '').encode('utf-8')
+    user_id = context.get('user_id', None)
+    time = get_event_time_string(raw_time).encode('utf-8')
+    ip = row.get('ip', '').encode('utf-8')
+    if not user_id or not time:
+        return ()
+    date_string = raw_time.split("T")[0].encode('utf-8')
+    if date_string < lower_bound_date_string or date_string >= upper_bound_date_string:
+        return ()  # discard events outside the date interval
+    return (user_id, course_id, ip, time, date_string)
+
+
+def parse_json_event(line, nested=False):
+    """
+    Parse a tracking log input line as JSON to create a dict representation.
+    """
+    try:
+        parsed = json.loads(line)
+    except Exception:
+        if not nested:
+            json_match = PATTERN_JSON.match(line)
+            if json_match:
+                return parse_json_event(json_match.group(1), nested=True)
+        return None
+    return parsed
+
+
+def load_and_filter(spark_session, file, lower_bound_date_string, upper_bound_date_string):
+    """Load a tracking log file as an RDD of filtered (user_id, course_id, ip, time, date) tuples."""
+    return spark_session.sparkContext.textFile(file) \
+        .map(parse_json_event) \
+        .map(lambda row: filter_event_logs(row, lower_bound_date_string, upper_bound_date_string)) \
+        .filter(bool)
+
+
+def validate_course_id(course_id):
+    """Return the normalized course_id if it is a valid course key, otherwise an empty string."""
+    course_id = opaque_key_util.normalize_course_id(course_id)
+    if course_id:
+        if opaque_key_util.is_valid_course_id(course_id):
+            return course_id
+    return ''
+
 def get_course_id(event_context, from_url=False):
     """
     Gets course_id from event's data.

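For reference, a quick standalone exercise of the new helpers, not part of the commit; the import path assumes the package is installed, and the sample inputs are made up. In load_and_filter, the trailing .filter(bool) works because filter_event_logs returns an empty tuple (falsy) for malformed or out-of-range rows, so those rows drop out of the RDD.

# Standalone sketch, not part of the commit: the import path and sample lines are assumptions.
from edx.analytics.tasks.util.spark_util import parse_json_event, validate_course_id

print(parse_json_event('{"event_type": "play_video"}'))   # plain JSON line -> dict
print(parse_json_event('Jan  1 10:00:00 host: {"event_type": "play_video"}'))  # prefixed line, JSON recovered via PATTERN_JSON
print(parse_json_event('not json at all'))                # -> None; filter_event_logs then returns () and filter(bool) drops it

print(validate_course_id('course-v1:edX+DemoX+Demo_Course'))  # a valid key is expected to come back (normalized)
print(validate_course_id('not-a-course-id'))                   # an invalid key is expected to come back as ''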