diff --git a/app.py b/app.py index 1972852..e746e10 100755 --- a/app.py +++ b/app.py @@ -7,7 +7,6 @@ import sys import time -import psycopg2 import pyspark import logger @@ -20,23 +19,6 @@ def get_arg(env, default): return os.getenv(env) if os.getenv(env, '') is not '' else default -def make_connection(host='127.0.0.1', port=5432, user='postgres', - password='postgres', dbname='postgres'): - """Connect to a postgresql db.""" - return psycopg2.connect(host=host, port=port, user=user, - password=password, dbname=dbname) - - -def build_connection(args): - """Make the db connection with an args object.""" - conn = make_connection(host=args.host, - port=args.port, - user=args.user, - password=args.password, - dbname=args.dbname) - return conn - - def parse_args(parser): """Parsing command line args.""" args = parser.parse_args() @@ -96,15 +78,13 @@ def main(arguments): # set up SQL connection try: - con = build_connection(arguments) + data_loader = storage.PostgresDataLoader(arguments) except IOError: loggers.error("Could not connect to data store") sys.exit(1) # fetch the data from the db - cursor = con.cursor() - cursor.execute("SELECT * FROM ratings") - ratings = cursor.fetchall() + ratings = data_loader.fetchall() loggers.info("Fetched data from table") # create an RDD of the ratings data ratingsRDD = sc.parallelize(ratings) @@ -154,20 +134,14 @@ def main(arguments): # check to see if new model should be created # select the maximum time stamp from the ratings database - cursor.execute( - "SELECT timestamp FROM ratings ORDER BY timestamp DESC LIMIT 1;" - ) - checking_max_timestamp = cursor.fetchone()[0] + checking_max_timestamp = data_loader.latest_timestamp() loggers.info( "The latest timestamp = {}". 
format(checking_max_timestamp)) if checking_max_timestamp > max_timestamp: # build a new model # first, fetch all new ratings - cursor.execute( - "SELECT * FROM ratings WHERE (timestamp > %s);", - (max_timestamp,)) - new_ratings = cursor.fetchall() + new_ratings = data_loader.fetchafter(max_timestamp) max_timestamp = checking_max_timestamp new_ratingsRDD = sc.parallelize(new_ratings) new_ratingsRDD = new_ratingsRDD.map(lambda x: (x[1], x[2], x[3])) diff --git a/storage.py b/storage.py index c9952d3..82dae4b 100644 --- a/storage.py +++ b/storage.py @@ -1,6 +1,7 @@ from abc import ABCMeta, abstractmethod -from pymongo import MongoClient import datetime +import psycopg2 +from pymongo import MongoClient class ModelWriter: @@ -63,3 +64,88 @@ def write(self, model, version): 'model_id': version, 'id': feature[0], 'features': list(feature[1])}) + + +class DataLoader: + """ + Abstract class for a Data Store loader. + Implement backend specific loaders as a subclass. + """ + __metaclass__ = ABCMeta + + def __init__(self, arguments): + """ + :param arguments: The database specific connection arguments + """ + self._arguments = arguments + + @abstractmethod + def fetchall(self): + """ + Returns all ratings in the database. Each rating must be in form + `(userid, productid, rating)` and in a collection type that Spark's + `parallelize` can accept (e.g. `List`) + :return: A collection of ratings, each with a user, product and rating. + """ + pass + + @abstractmethod + def latest_timestamp(self): + """ + Returns the timestamp for the most recent (chronologically) + rating in the database. + :return: A timestamp + """ + pass + + @abstractmethod + def fetchafter(self, timestamp): + """ + Returns all ratings (in the same format as `fetchall`) added after + `timestamp`. + :param timestamp: A timestamp + :return: A collection of ratings, each with a user, product and rating. + """ + pass + + +class PostgresDataLoader(DataLoader): + """ + Data store loader for a PostgreSQL backend. 
+ """ + + def __init__(self, arguments): + super(PostgresDataLoader, self).__init__(arguments) + self._connection = self._build_connection(arguments) + self._cursor = self._connection.cursor() + + def _make_connection(self, host='127.0.0.1', port=5432, user='postgres', + password='postgres', dbname='postgres'): + """Connect to a postgresql db.""" + return psycopg2.connect(host=host, port=port, user=user, + password=password, dbname=dbname) + + def _build_connection(self, args): + """Make the db connection with an args object.""" + conn = self._make_connection(host=args.host, + port=args.port, + user=args.user, + password=args.password, + dbname=args.dbname) + return conn + + def fetchall(self): + self._cursor.execute("SELECT * FROM ratings") + return self._cursor.fetchall() + + def latest_timestamp(self): + self._cursor.execute( + "SELECT timestamp FROM ratings ORDER BY timestamp DESC LIMIT 1;" + ) + return self._cursor.fetchone()[0] + + def fetchafter(self, timestamp): + self._cursor.execute( + "SELECT * FROM ratings WHERE (timestamp > %s);", + (timestamp,)) + return self._cursor.fetchall()