Skip to content

Commit

Permalink
Modified predict to compute representations just once; added predict_…
Browse files Browse the repository at this point in the history
…score with same input as predict_rank, tests
  • Loading branch information
paoloRais committed Aug 25, 2016
1 parent c90bbf9 commit 5cf8e21
Show file tree
Hide file tree
Showing 5 changed files with 2,280 additions and 566 deletions.
121 changes: 100 additions & 21 deletions lightfm/_lightfm_fast.pyx.template
Original file line number Diff line number Diff line change
Expand Up @@ -1139,51 +1139,129 @@ def fit_bpr(CSRMatrix item_features,
user_alpha)


def predict_lightfm(CSRMatrix item_features,
CSRMatrix user_features,
int[::1] user_ids,
int[::1] item_ids,
double[::1] predictions,
FastLightFM lightfm,
int num_threads):
cdef precompute_unique(CSRMatrix item_features,
                       CSRMatrix user_features,
                       int[::1] unique_users,
                       int[::1] unique_items,
                       flt *user_reprs,
                       flt *it_reprs,
                       FastLightFM lightfm,
                       int num_threads):
    """
    Precompute the representations of all the users in unique_users and
    all the items in unique_items.

    Row i of user_reprs (lightfm.no_components + 1 consecutive values)
    receives the representation computed for unique_users[i]; row i of
    it_reprs receives the one computed for unique_items[i].

    Both output buffers are owned by the caller and must hold at least
    unique_users.shape[0] (resp. unique_items.shape[0]) rows.
    """

    cdef int i, j
    cdef int stride
    cdef flt *it_repr
    cdef flt *user_repr
    cdef int no_features
    cdef int no_users

    no_features = unique_items.shape[0]
    no_users = unique_users.shape[0]

    # Length of one representation row in the flat output buffers.
    stride = lightfm.no_components + 1

    {nogil_block}

        # Scratch rows handed to compute_representation. They are allocated
        # inside the block so that, if the template expands it to a parallel
        # section, each worker gets its own pair (TODO confirm expansion).
        user_repr = <flt *>malloc(sizeof(flt) * stride)
        it_repr = <flt *>malloc(sizeof(flt) * stride)

        # users representations
        for i in {range_block}(no_users):
            compute_representation(user_features,
                                   lightfm.user_features,
                                   lightfm.user_biases,
                                   lightfm,
                                   unique_users[i],
                                   lightfm.user_scale,
                                   user_repr)
            # Sequential copy of one row; the parallelism comes from the
            # outer loop, so the inner loop is a plain range.
            for j in range(stride):
                user_reprs[i * stride + j] = user_repr[j]

        # items representations
        for i in {range_block}(no_features):
            compute_representation(item_features,
                                   lightfm.item_features,
                                   lightfm.item_biases,
                                   lightfm,
                                   unique_items[i],
                                   lightfm.item_scale,
                                   it_repr)
            for j in range(stride):
                it_reprs[i * stride + j] = it_repr[j]

        # The scratch rows are local to this function; release them here
        # so they are not leaked on every call.
        free(user_repr)
        free(it_repr)


def predict_lightfm(CSRMatrix item_features,
                    CSRMatrix user_features,
                    int[::1] user_ids,
                    int[::1] item_ids,
                    double[::1] predictions,
                    FastLightFM lightfm,
                    int num_threads,
                    bint precompute):
    """
    Generate predictions.

    For every i in range(predictions.shape[0]) the score of the pair
    (user_ids[i], item_ids[i]) is written to predictions[i]; user_ids and
    item_ids are indexed with the same i, so they must be at least that
    long.

    When precompute is true, the representation of each distinct user id
    and each distinct item id is computed once up front and reused for
    every example that refers to it. Otherwise both representations are
    recomputed for each example.

    Raises MemoryError if the buffers for the precomputed representations
    cannot be allocated.
    """

    cdef int i, j, no_examples
    cdef int stride
    cdef flt *user_repr
    cdef flt *it_repr
    # Only allocated on the precompute path; NULL otherwise so the
    # pointers are never left uninitialised.
    cdef flt *user_reprs = NULL
    cdef flt *it_reprs = NULL
    cdef int[::1] unique_users
    cdef int[::1] unique_items
    # np.unique returns the inverse indices as intp, which is not the
    # same width as C long on every platform; Py_ssize_t matches intp.
    cdef Py_ssize_t[::1] inverse_users
    cdef Py_ssize_t[::1] inverse_items
    cdef int no_features
    cdef int no_users

    no_examples = predictions.shape[0]

    # Length of one representation row.
    stride = lightfm.no_components + 1

    if precompute:
        # unique_* hold the distinct ids; inverse_*[i] is the position of
        # example i's id inside unique_*, i.e. its row in the *_reprs table.
        unique_users, inverse_users = np.unique(user_ids, return_inverse=True)
        unique_items, inverse_items = np.unique(item_ids, return_inverse=True)
        no_features = unique_items.shape[0]
        no_users = unique_users.shape[0]

        user_reprs = <flt *>malloc(sizeof(flt) * no_users * stride)
        it_reprs = <flt *>malloc(sizeof(flt) * no_features * stride)

        # malloc may legitimately return NULL for a zero-byte request, so
        # only treat NULL as a failure when rows were actually requested.
        if ((user_reprs == NULL and no_users > 0)
                or (it_reprs == NULL and no_features > 0)):
            free(user_reprs)
            free(it_reprs)
            raise MemoryError()

        precompute_unique(item_features,
                          user_features,
                          unique_users,
                          unique_items,
                          user_reprs,
                          it_reprs,
                          lightfm,
                          num_threads)

    {nogil_block}

        # Scratch rows holding the representations of the current example.
        user_repr = <flt *>malloc(sizeof(flt) * stride)
        it_repr = <flt *>malloc(sizeof(flt) * stride)

        for i in {range_block}(no_examples):
            if precompute:
                # Look the rows up in the precomputed tables.
                for j in range(stride):
                    user_repr[j] = user_reprs[inverse_users[i] * stride + j]
                    it_repr[j] = it_reprs[inverse_items[i] * stride + j]
            else:
                compute_representation(user_features,
                                       lightfm.user_features,
                                       lightfm.user_biases,
                                       lightfm,
                                       user_ids[i],
                                       lightfm.user_scale,
                                       user_repr)
                compute_representation(item_features,
                                       lightfm.item_features,
                                       lightfm.item_biases,
                                       lightfm,
                                       item_ids[i],
                                       lightfm.item_scale,
                                       it_repr)

            predictions[i] = compute_prediction_from_repr(user_repr,
                                                          it_repr,
                                                          lightfm.no_components)

        free(user_repr)
        free(it_repr)

    if precompute:
        free(user_reprs)
        free(it_reprs)


def predict_ranks(CSRMatrix item_features,
Expand Down Expand Up @@ -1341,3 +1419,4 @@ def __test_in_positives(int row, int col, CSRMatrix mat):
return True
else:
return False

Loading

0 comments on commit 5cf8e21

Please sign in to comment.