diff --git a/opencve/extensions.py b/opencve/extensions.py index de79b618..a81e7080 100644 --- a/opencve/extensions.py +++ b/opencve/extensions.py @@ -1,11 +1,12 @@ from celery import Celery from flask_admin import Admin from flask_debugtoolbar import DebugToolbarExtension +from flask_login import current_user from flask_gravatar import Gravatar from flask_migrate import Migrate from flask_sqlalchemy import SQLAlchemy from flask_user import UserManager, EmailManager -from flask_user.forms import EditUserProfileForm, RegisterForm +from flask_user.forms import EditUserProfileForm, RegisterForm, unique_email_validator from flask_wtf import RecaptchaField from flask_wtf.csrf import CSRFProtect from wtforms import validators, StringField @@ -19,10 +20,26 @@ class CustomUserManager(UserManager): """ def customize(self, app): - # Add the email field + def _unique_email_validator(form, field): + """ + Check if the new email is unique. Skip this step if the + email is the same as the current one. + """ + if field.data.lower() == current_user.email.lower(): + return + unique_email_validator(form, field) + + # Add the email field and make first and last names as not required class CustomUserProfileForm(EditUserProfileForm): + first_name = StringField("First name") + last_name = StringField("Last name") email = StringField( - "Email", validators=[validators.DataRequired(), validators.Email()] + "Email", + validators=[ + validators.DataRequired(), + validators.Email(), + _unique_email_validator, + ], ) self.EditUserProfileFormClass = CustomUserProfileForm diff --git a/tests/conftest.py b/tests/conftest.py index 5c7804a5..3a9ccf48 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ import pytest from bs4 import BeautifulSoup +from flask import url_for from opencve import create_app from opencve.commands.utils import CveUtil @@ -151,3 +152,17 @@ def _get_cve_names(soup): ] return _get_cve_names + + +@pytest.fixture(scope="function") +def login(create_user, client): + create_user() + client.post( + url_for("user.login"), + data={"username": "user", "password": "password"}, + follow_redirects=True, + ) + + yield + + client.post(url_for("user.logout"), follow_redirects=True) diff --git a/tests/controllers/test_profile.py b/tests/controllers/test_profile.py new file mode 100644 index 00000000..cc6ede1b --- /dev/null +++ b/tests/controllers/test_profile.py @@ -0,0 +1,55 @@ +import pytest +from flask import request + +from opencve.models.users import User + + +@pytest.mark.parametrize( + "first_name,last_name,email", + [ + ("john", "doe", "john.doe@example.com"), + ("john", "", "john.doe@example.com"), + ("", "doe", "john.doe@example.com"), + ("", "", "john.doe@example.com"), + ("", "", "user@opencve.io"), + ], +) +def test_edit_profile(login, client, first_name, last_name, email): + user = User.query.first() + assert user.first_name == "" + assert user.last_name == "" + assert user.email == "user@opencve.io" + + client.post( + "/account/profile", + data={"first_name": first_name, "last_name": last_name, "email": email}, + follow_redirects=True, + ) + + user = User.query.first() + assert user.first_name == first_name + assert user.last_name == last_name + assert user.email == email + + +def test_edit_profile_email_required(login, client): + response = client.post( + "/account/profile", data={"email": ""}, follow_redirects=True + ) + assert b"This field is required." in response.data + + +def test_edit_profile_with_existing_email(login, create_user, client): + user = User.query.filter_by(username="user").first() + assert user.email == "user@opencve.io" + + create_user("user2") + + response = client.post( + "/account/profile", data={"email": "user2@opencve.io"}, follow_redirects=True + ) + assert b"This Email is already in use. Please try another one" in response.data + + # Email has not changed + user = User.query.filter_by(username="user").first() + assert user.email == "user@opencve.io"