forked from NicolasHug/Surprise
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathk_nearest_neighbors.py
60 lines (47 loc) · 2.04 KB
/
k_nearest_neighbors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
This module illustrates how to retrieve the k-nearest neighbors of an item. The
same can be done for users with minor changes. There's a lot of boilerplate
because of the id conversions, but it all boils down to the use of
algo.get_neighbors().
"""
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import io  # needed because of weird encoding of u.item file
import os

from surprise import Dataset
from surprise import KNNBaseline
from surprise import get_dataset_dir
def read_item_names(file_name=None):
    """Read a MovieLens 100k ``u.item`` file and build id <-> name mappings.

    Every line of the file is expected to hold '|'-separated fields, the
    first one being the raw item id and the second one the movie name. Lines
    with fewer than two fields (e.g. blank lines) are skipped.

    Args:
        file_name: Optional path to the ``u.item`` file. When omitted, the
            file is looked up in ``ml-100k/ml-100k/u.item`` under the
            directory returned by ``get_dataset_dir()``.

    Returns:
        A tuple ``(rid_to_name, name_to_rid)`` of two dicts: the first maps
        raw ids to movie names, the second maps movie names to raw ids. Raw
        ids are kept as strings, exactly as read from the file. If the same
        name shows up on several lines, the last raw id read wins.
    """
    if file_name is None:
        file_name = os.path.join(get_dataset_dir(),
                                 'ml-100k', 'ml-100k', 'u.item')

    rid_to_name = {}
    name_to_rid = {}
    # io.open is used (instead of the builtin open) so that the ISO-8859-1
    # encoding of the file can be given explicitly on Python 2 as well.
    with io.open(file_name, 'r', encoding='ISO-8859-1') as f:
        for line in f:
            # Drop the line terminator so it can never end up in a name.
            fields = line.rstrip('\n').split('|')
            if len(fields) < 2:
                # Blank or separator-less line: no (id, name) pair to record.
                continue
            raw_id, name = fields[0], fields[1]
            rid_to_name[raw_id] = name
            name_to_rid[name] = raw_id

    return rid_to_name, name_to_rid
# Fit an item-based KNN model on the whole dataset, so that the similarities
# between items get computed.
data = Dataset.load_builtin('ml-100k')
trainset = data.build_full_trainset()
sim_options = {'name': 'pearson_baseline', 'user_based': False}
algo = KNNBaseline(sim_options=sim_options)
algo.fit(trainset)

# Build the lookup tables: raw id -> movie name and movie name -> raw id.
rid_to_name, name_to_rid = read_item_names()

# Toy Story: movie name -> raw id -> inner id used by the trainset.
toy_story_raw_id = name_to_rid['Toy Story (1995)']
toy_story_inner_id = algo.trainset.to_inner_iid(toy_story_raw_id)

# Ask the algorithm for the inner ids of the 10 nearest neighbors.
toy_story_neighbors = algo.get_neighbors(toy_story_inner_id, k=10)

# Lazily translate each neighbor: inner id -> raw id -> movie name.
toy_story_neighbors = (rid_to_name[algo.trainset.to_raw_iid(inner_id)]
                       for inner_id in toy_story_neighbors)

print()
print('The 10 nearest neighbors of Toy Story are:')
for movie in toy_story_neighbors:
    print(movie)