Skip to content

Commit

Permalink
Added MSE, from Marc Feger's PR 225
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Sep 12, 2019
1 parent 3b02680 commit 281c8eb
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 6 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ Contributors
The following persons have contributed to [Surprise](http://surpriselib.com):

caoyi, Олег Демиденко, Charles-Emmanuel Dias, dmamylin, Lauriane Ducasse,
franckjay, Lukas Galke, Pierre-François Gimenez, Zachary Glassman, Nicolas
Hug, Janniks, Doruk Kilitcioglu, Ravi Raju Krishna, Hengji Liu, Maher
Malaeb, Manoj K, Naturale0, nju-luke, Jay Qi, Skywhat, David Stevens, Victor
Wang, Mike Lee Williams, Jay Wong, Chenchen Xu, YaoZh1918.
Marc Feger, franckjay, Lukas Galke, Pierre-François Gimenez, Zachary
Glassman, Nicolas Hug, Janniks, Doruk Kilitcioglu, Ravi Raju Krishna, Hengji
Liu, Maher Malaeb, Manoj K, Naturale0, nju-luke, Jay Qi, Skywhat, David
Stevens, Victor Wang, Mike Lee Williams, Jay Wong, Chenchen Xu, YaoZh1918.

Thanks a lot :) !

Expand Down
35 changes: 35 additions & 0 deletions surprise/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
:nosignatures:
rmse
mse
mae
fcp
"""
Expand Down Expand Up @@ -54,6 +55,40 @@ def rmse(predictions, verbose=True):
return rmse_


def mse(predictions, verbose=True):
    """Compute MSE (Mean Squared Error).

    .. math::
        \\text{MSE} = \\frac{1}{|\\hat{R}|} \\sum_{\\hat{r}_{ui} \\in
        \\hat{R}}(r_{ui} - \\hat{r}_{ui})^2.

    Args:
        predictions (:obj:`list` of :obj:`Prediction\
            <surprise.prediction_algorithms.predictions.Prediction>`):
            A list of predictions, as returned by the :meth:`test()
            <surprise.prediction_algorithms.algo_base.AlgoBase.test>` method.
        verbose: If True, will print computed value. Default is ``True``.

    Returns:
        The Mean Squared Error of predictions.

    Raises:
        ValueError: When ``predictions`` is empty.
    """

    if not predictions:
        raise ValueError('Prediction list is empty.')

    # Average of squared errors over all (uid, iid, true_r, est, details)
    # prediction tuples; only the true rating and the estimate are used.
    mse_ = np.mean([float((true_r - est)**2)
                    for (_, _, true_r, est, _) in predictions])

    if verbose:
        print('MSE: {0:1.4f}'.format(mse_))

    return mse_


def mae(predictions, verbose=True):
"""Compute MAE (Mean Absolute Error).
Expand Down
18 changes: 17 additions & 1 deletion tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from surprise.accuracy import mae, rmse, fcp
from surprise.accuracy import mae, rmse, fcp, mse


def pred(true_r, est, u0=None):
Expand Down Expand Up @@ -67,3 +67,19 @@ def test_fcp():

with pytest.raises(ValueError):
fcp([])


def test_mse():
    """Tests for the MSE function."""

    # Perfect predictions have zero mean squared error.
    perfect = [pred(r, r) for r in (0, 1, 2, 100)]
    assert mse(perfect) == 0

    # One miss of magnitude 2, averaged over two predictions.
    one_miss = [pred(0, 0), pred(0, 2)]
    assert mse(one_miss) == ((0 - 2) ** 2) / 2

    # Errors are squared before being averaged.
    two_misses = [pred(2, 0), pred(3, 4)]
    assert mse(two_misses) == ((2 - 0) ** 2 + (3 - 4) ** 2) / 2

    # An empty prediction list is rejected.
    with pytest.raises(ValueError):
        mse([])
3 changes: 2 additions & 1 deletion tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ def test_performances():
algo = NormalPredictor()
tmp_dir = tempfile.mkdtemp() # create tmp dir
with pytest.warns(UserWarning):
performances = evaluate(algo, data, measures=['RmSe', 'Mae'],
performances = evaluate(algo, data, measures=['RmSe', 'Mae', 'mse'],
with_dump=True, dump_dir=tmp_dir, verbose=2)
shutil.rmtree(tmp_dir) # remove tmp dir

assert performances['RMSE'] is performances['rmse']
assert performances['MaE'] is performances['mae']
assert performances['MsE'] is performances['mse']

0 comments on commit 281c8eb

Please sign in to comment.