forked from NicolasHug/Surprise
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_SVD.py
165 lines (129 loc) · 6.08 KB
/
test_SVD.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
Module for testing the SVD and SVD++ algorithms.
"""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from surprise import SVD
from surprise import SVDpp
from surprise.model_selection import cross_validate
def test_SVD_parameters(u1_ml100k, pkf):
    """Check that every SVD parameter has an effect on the test RMSE.

    A reference model is built with ``n_factors=1, n_epochs=1,
    random_state=1``. Each check then changes exactly one setting, runs
    ``cross_validate`` again with the same data and folds, and asserts that
    the resulting ``'test_rmse'`` differs from the reference one.

    ``u1_ml100k`` and ``pkf`` are forwarded unchanged to ``cross_validate``
    as its data and cv arguments (presumably pytest fixtures declared
    outside this module -- confirm in conftest).
    """

    def rmse_with(**overrides):
        # Reference settings; any override passed in takes precedence.
        settings = dict(n_factors=1, n_epochs=1, random_state=1)
        settings.update(overrides)
        scores = cross_validate(SVD(**settings), u1_ml100k, ['rmse'], pkf)
        return scores['test_rmse']

    # The baseline against which every variant is compared.
    # NOTE(review): 'test_rmse' is documented as a numpy array, so the bare
    # `!=` asserts below presumably rely on `pkf` yielding one fold -- confirm.
    reference = rmse_with()

    # n_factors / n_epochs
    assert reference != rmse_with(n_factors=2)
    assert reference != rmse_with(n_epochs=2)

    # biased
    assert reference != rmse_with(biased=False)

    # lr_all / reg_all
    assert reference != rmse_with(lr_all=5)
    assert reference != rmse_with(reg_all=5)

    # Individual learning rates: lr_bu, lr_bi, lr_pu, lr_qi
    assert reference != rmse_with(lr_bu=5)
    assert reference != rmse_with(lr_bi=5)
    assert reference != rmse_with(lr_pu=5)
    assert reference != rmse_with(lr_qi=5)

    # Individual regularization terms: reg_bu, reg_bi, reg_pu, reg_qi
    assert reference != rmse_with(reg_bu=5)
    assert reference != rmse_with(reg_bi=5)
    assert reference != rmse_with(reg_pu=5)
    assert reference != rmse_with(reg_qi=5)
def test_SVDpp_parameters(u1_ml100k, pkf):
    """Check that SVDpp parameters have an effect on the test RMSE.

    A reference model is built with ``n_factors=1, n_epochs=1,
    random_state=1`` and compared against a variant with ``n_factors=2``;
    the two ``'test_rmse'`` results returned by ``cross_validate`` must
    differ. Checks for the remaining parameters are written out but
    disabled (see the end of the function).

    ``u1_ml100k`` and ``pkf`` are forwarded unchanged to ``cross_validate``
    as its data and cv arguments (presumably pytest fixtures declared
    outside this module -- confirm in conftest).
    """

    def rmse_with(**overrides):
        # Reference settings; any override passed in takes precedence.
        settings = dict(n_factors=1, n_epochs=1, random_state=1)
        settings.update(overrides)
        scores = cross_validate(SVDpp(**settings), u1_ml100k, ['rmse'], pkf)
        return scores['test_rmse']

    # The baseline against which to compare. The name `rmse_default` is
    # kept because the disabled checks below refer to it.
    rmse_default = rmse_with()

    # n_factors
    assert rmse_default != rmse_with(n_factors=2)

    # The rest is OK but just takes too long for now...
    # The disabled checks live in a bare string literal: it is evaluated as
    # a no-op expression, so none of the code inside it runs.
    """
    # n_epochs
    algo = SVDpp(n_factors=1, n_epochs=2, random_state=1)
    rmse_n_epochs = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_n_epochs
    # lr_all
    algo = SVDpp(n_factors=1, n_epochs=1, lr_all=5, random_state=1)
    rmse_lr_all = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_lr_all
    # reg_all
    algo = SVDpp(n_factors=1, n_epochs=1, reg_all=5, random_state=1)
    rmse_reg_all = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_reg_all
    # lr_bu
    algo = SVDpp(n_factors=1, n_epochs=1, lr_bu=5, random_state=1)
    rmse_lr_bu = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_lr_bu
    # lr_bi
    algo = SVDpp(n_factors=1, n_epochs=1, lr_bi=5, random_state=1)
    rmse_lr_bi = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_lr_bi
    # lr_pu
    algo = SVDpp(n_factors=1, n_epochs=1, lr_pu=5, random_state=1)
    rmse_lr_pu = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_lr_pu
    # lr_qi
    algo = SVDpp(n_factors=1, n_epochs=1, lr_qi=5, random_state=1)
    rmse_lr_qi = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_lr_qi
    # lr_yj
    algo = SVDpp(n_factors=1, n_epochs=1, lr_yj=5, random_state=1)
    rmse_lr_yj = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_lr_yj
    # reg_bu
    algo = SVDpp(n_factors=1, n_epochs=1, reg_bu=5, random_state=1)
    rmse_reg_bu = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_reg_bu
    # reg_bi
    algo = SVDpp(n_factors=1, n_epochs=1, reg_bi=5, random_state=1)
    rmse_reg_bi = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_reg_bi
    # reg_pu
    algo = SVDpp(n_factors=1, n_epochs=1, reg_pu=5, random_state=1)
    rmse_reg_pu = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_reg_pu
    # reg_qi
    algo = SVDpp(n_factors=1, n_epochs=1, reg_qi=5, random_state=1)
    rmse_reg_qi = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_reg_qi
    # reg_yj
    algo = SVDpp(n_factors=1, n_epochs=1, reg_yj=5, random_state=1)
    rmse_reg_yj = cross_validate(algo, u1_ml100k, ['rmse'], pkf)['test_rmse']
    assert rmse_default != rmse_reg_yj
    """