Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added abstract DataLoader to enable simple backend implementations #38

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 4 additions & 30 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
import time

import psycopg2
import pyspark

import logger
Expand All @@ -20,23 +19,6 @@ def get_arg(env, default):
return os.getenv(env) if os.getenv(env, '') is not '' else default


def make_connection(host='127.0.0.1', port=5432, user='postgres',
                    password='postgres', dbname='postgres'):
    """Open a connection to a PostgreSQL database.

    Every parameter is forwarded unchanged, by keyword, to
    ``psycopg2.connect``; whatever that call returns is handed back.
    """
    settings = {
        'host': host,
        'port': port,
        'user': user,
        'password': password,
        'dbname': dbname,
    }
    return psycopg2.connect(**settings)


def build_connection(args):
    """Create the database connection described by an args object.

    Reads the ``host``, ``port``, ``user``, ``password`` and ``dbname``
    attributes of ``args`` and passes them on to ``make_connection``.
    """
    return make_connection(
        host=args.host,
        port=args.port,
        user=args.user,
        password=args.password,
        dbname=args.dbname,
    )


def parse_args(parser):
"""Parsing command line args."""
args = parser.parse_args()
Expand Down Expand Up @@ -96,15 +78,13 @@ def main(arguments):

# set up SQL connection
try:
con = build_connection(arguments)
data_loader = storage.PostgresDataLoader(arguments)
except IOError:
loggers.error("Could not connect to data store")
sys.exit(1)

# fetch the data from the db
cursor = con.cursor()
cursor.execute("SELECT * FROM ratings")
ratings = cursor.fetchall()
ratings = data_loader.fetchall()
loggers.info("Fetched data from table")
# create an RDD of the ratings data
ratingsRDD = sc.parallelize(ratings)
Expand Down Expand Up @@ -154,20 +134,14 @@ def main(arguments):

# check to see if new model should be created
# select the maximum time stamp from the ratings database
cursor.execute(
"SELECT timestamp FROM ratings ORDER BY timestamp DESC LIMIT 1;"
)
checking_max_timestamp = cursor.fetchone()[0]
checking_max_timestamp = data_loader.latest_timestamp()
loggers.info(
"The latest timestamp = {}". format(checking_max_timestamp))

if checking_max_timestamp > max_timestamp:
# build a new model
# first, fetch all new ratings
cursor.execute(
"SELECT * FROM ratings WHERE (timestamp > %s);",
(max_timestamp,))
new_ratings = cursor.fetchall()
new_ratings = data_loader.fetchafter(max_timestamp)
max_timestamp = checking_max_timestamp
new_ratingsRDD = sc.parallelize(new_ratings)
new_ratingsRDD = new_ratingsRDD.map(lambda x: (x[1], x[2], x[3]))
Expand Down
88 changes: 87 additions & 1 deletion storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABCMeta, abstractmethod
from pymongo import MongoClient
import datetime
import psycopg2
from pymongo import MongoClient


class ModelWriter:
Expand Down Expand Up @@ -63,3 +64,88 @@ def write(self, model, version):
'model_id': version,
'id': feature[0],
'features': list(feature[1])})


# The base is built by calling ABCMeta directly rather than by setting
# `__metaclass__`: Python 3 ignores the `__metaclass__` attribute, which
# silently disables @abstractmethod enforcement. Calling the metaclass works
# identically on Python 2 and Python 3.
class DataLoader(ABCMeta('_DataLoaderBase', (object,), {})):
    """
    Abstract class for a Data Store loader.
    Implement backend specific loaders as a subclass.

    Subclasses must implement `fetchall`, `latest_timestamp` and
    `fetchafter`; a class that leaves any of them abstract cannot be
    instantiated (``TypeError``).
    """

    def __init__(self, arguments):
        """
        :param arguments: The database specific connection arguments
        """
        self._arguments = arguments

    @abstractmethod
    def fetchall(self):
        """
        Returns all ratings in the database. Each rating must be in form
        `(userid, productid, rating)` and in a collection type that Spark's
        `parallelize` can accept (e.g. `List`)

        NOTE(review): confirm the exact row layout against the caller;
        backends that return whole table rows may carry extra columns.

        :return: A collection of ratings, each with a user, product and rating.
        """

    @abstractmethod
    def latest_timestamp(self):
        """
        Returns the timestamp for the most recent (chronologically)
        rating in the database.
        :return: A timestamp
        """

    @abstractmethod
    def fetchafter(self, timestamp):
        """
        Returns all ratings (in the same format as `fetchall`) added after
        `timestamp`.
        :param timestamp: A timestamp
        :return: A collection of ratings, each with a user, product and rating.
        """


class PostgresDataLoader(DataLoader):
    """
    Data store loader for a PostgreSQL backend.

    A single connection and cursor are opened at construction time and
    reused for every query. Connection failures propagate as whatever
    ``psycopg2.connect`` raises. Call `close` when the loader is no longer
    needed to release the cursor and connection.
    """

    def __init__(self, arguments):
        """
        :param arguments: An object exposing `host`, `port`, `user`,
            `password` and `dbname` attributes for the connection.
        """
        super(PostgresDataLoader, self).__init__(arguments)
        self._connection = self._build_connection(arguments)
        self._cursor = self._connection.cursor()

    def _make_connection(self, host='127.0.0.1', port=5432, user='postgres',
                         password='postgres', dbname='postgres'):
        """Connect to a postgresql db."""
        return psycopg2.connect(host=host, port=port, user=user,
                                password=password, dbname=dbname)

    def _build_connection(self, args):
        """Make the db connection with an args object."""
        return self._make_connection(host=args.host,
                                     port=args.port,
                                     user=args.user,
                                     password=args.password,
                                     dbname=args.dbname)

    def fetchall(self):
        """
        Returns every row of the `ratings` table.
        :return: The rows produced by the cursor's `fetchall`.
        """
        self._cursor.execute("SELECT * FROM ratings")
        return self._cursor.fetchall()

    def latest_timestamp(self):
        """
        Returns the largest `timestamp` value in the `ratings` table.
        :return: The timestamp, or `None` when the table has no rows.
        """
        self._cursor.execute(
            "SELECT timestamp FROM ratings ORDER BY timestamp DESC LIMIT 1;"
        )
        row = self._cursor.fetchone()
        # fetchone() yields None on an empty result set; indexing it would
        # fail with an unhelpful TypeError, so report "no timestamp" instead.
        if row is None:
            return None
        return row[0]

    def fetchafter(self, timestamp):
        """
        Returns every row of `ratings` whose timestamp is strictly greater
        than `timestamp`.
        :param timestamp: A timestamp
        :return: The rows produced by the cursor's `fetchall`.
        """
        self._cursor.execute(
            "SELECT * FROM ratings WHERE (timestamp > %s);",
            (timestamp,))
        return self._cursor.fetchall()

    def close(self):
        """Release the cursor and the underlying database connection."""
        try:
            self._cursor.close()
        finally:
            # Always release the connection, even if closing the cursor fails.
            self._connection.close()