Skip to content

Commit

Permalink
tests: added tests to the account domain (go-shiori#993)
Browse files Browse the repository at this point in the history
* tests: added tests to account domain

* refactor: ensure expiration comes from same value

* refactor: jwtclaims to model package

* refactor: add testutil.GetValidAccount
  • Loading branch information
fmartingr authored Nov 1, 2024
1 parent 8c35a6b commit 4a58ef0
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 73 deletions.
37 changes: 20 additions & 17 deletions internal/domains/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,35 @@ type AccountsDomain struct {
deps *dependencies.Dependencies
}

type JWTClaim struct {
jwt.RegisteredClaims

Account *model.Account
}

func (d *AccountsDomain) CheckToken(ctx context.Context, userJWT string) (*model.Account, error) {
token, err := jwt.ParseWithClaims(userJWT, &JWTClaim{}, func(token *jwt.Token) (interface{}, error) {
func (d *AccountsDomain) ParseToken(userJWT string) (*model.JWTClaim, error) {
token, err := jwt.ParseWithClaims(userJWT, &model.JWTClaim{}, func(token *jwt.Token) (interface{}, error) {
// Validate algorithm
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("Unexpected signing method: %v", token.Header["alg"])
}

return d.deps.Config.Http.SecretKey, nil
})
if err != nil {
return nil, errors.Wrap(err, "error parsing token")
}

if claims, ok := token.Claims.(*JWTClaim); ok && token.Valid {
if claims.Account.ID > 0 {
return claims.Account, nil
}
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*model.JWTClaim); ok && token.Valid {
return claims, nil
}

return nil, fmt.Errorf("error obtaining user from JWT claims")
}

func (d *AccountsDomain) CheckToken(ctx context.Context, userJWT string) (*model.Account, error) {
claims, err := d.ParseToken(userJWT)
if err != nil {
return nil, fmt.Errorf("error parsing token: %w", err)
}

if claims.Account.ID > 0 {
return claims.Account, nil
}
return nil, fmt.Errorf("error obtaining user from JWT claims")
return nil, fmt.Errorf("error obtaining user from JWT claims: %w", err)
}

func (d *AccountsDomain) GetAccountFromCredentials(ctx context.Context, username, password string) (*model.Account, error) {
Expand All @@ -62,6 +61,10 @@ func (d *AccountsDomain) GetAccountFromCredentials(ctx context.Context, username
}

func (d *AccountsDomain) CreateTokenForAccount(account *model.Account, expiration time.Time) (string, error) {
if account == nil {
return "", fmt.Errorf("account is nil")
}

claims := jwt.MapClaims{
"account": account.ToDTO(),
"exp": expiration.UTC().Unix(),
Expand Down
145 changes: 145 additions & 0 deletions internal/domains/accounts_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
package domains_test

import (
"context"
"testing"
"time"

"github.com/go-shiori/shiori/internal/domains"
"github.com/go-shiori/shiori/internal/model"
"github.com/go-shiori/shiori/internal/testutil"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
)

func TestAccountsDomainParseToken(t *testing.T) {
ctx := context.TODO()
logger := logrus.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
domain := domains.NewAccountsDomain(deps)

t.Run("valid token", func(t *testing.T) {
// Create a valid token
token, err := domain.CreateTokenForAccount(
testutil.GetValidAccount(),
time.Now().Add(time.Hour*1),
)
require.NoError(t, err)

claims, err := domain.ParseToken(token)
require.NoError(t, err)
require.NotNil(t, claims)
require.Equal(t, 99, claims.Account.ID)
})

t.Run("invalid token", func(t *testing.T) {
claims, err := domain.ParseToken("invalid-token")
require.Error(t, err)
require.Nil(t, claims)
})
}

func TestAccountsDomainCheckToken(t *testing.T) {
ctx := context.TODO()
logger := logrus.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
domain := domains.NewAccountsDomain(deps)

t.Run("valid token", func(t *testing.T) {
// Create a valid token
token, err := domain.CreateTokenForAccount(
testutil.GetValidAccount(),
time.Now().Add(time.Hour*1),
)
require.NoError(t, err)

acc, err := domain.CheckToken(ctx, token)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, 99, acc.ID)
})

t.Run("expired token", func(t *testing.T) {
// Create an expired token
token, err := domain.CreateTokenForAccount(
testutil.GetValidAccount(),
time.Now().Add(time.Hour*-1),
)
require.NoError(t, err)

acc, err := domain.CheckToken(ctx, token)
require.Error(t, err)
require.Nil(t, acc)
})
}

func TestAccountsDomainGetAccountFromCredentials(t *testing.T) {
ctx := context.TODO()
logger := logrus.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
domain := domains.NewAccountsDomain(deps)

require.NoError(t, deps.Database.SaveAccount(ctx, model.Account{
Username: "test",
Password: "test",
}))

t.Run("valid credentials", func(t *testing.T) {
acc, err := domain.GetAccountFromCredentials(ctx, "test", "test")
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, "test", acc.Username)
})

t.Run("invalid credentials", func(t *testing.T) {
acc, err := domain.GetAccountFromCredentials(ctx, "test", "invalid")
require.Error(t, err)
require.Nil(t, acc)
})

t.Run("invalid username", func(t *testing.T) {
acc, err := domain.GetAccountFromCredentials(ctx, "nope", "invalid")
require.Error(t, err)
require.Nil(t, acc)
})

}

func TestAccountsDomainCreateTokenForAccount(t *testing.T) {
ctx := context.TODO()
logger := logrus.New()
_, deps := testutil.GetTestConfigurationAndDependencies(t, ctx, logger)
domain := domains.NewAccountsDomain(deps)

t.Run("valid account", func(t *testing.T) {
token, err := domain.CreateTokenForAccount(
testutil.GetValidAccount(),
time.Now().Add(time.Hour*1),
)
require.NoError(t, err)
require.NotEmpty(t, token)
})

t.Run("nil account", func(t *testing.T) {
token, err := domain.CreateTokenForAccount(
nil,
time.Now().Add(time.Hour*1),
)
require.Error(t, err)
require.Empty(t, token)
})

t.Run("token expiration is valid", func(t *testing.T) {
expiration := time.Now().Add(time.Hour * 9)
token, err := domain.CreateTokenForAccount(
testutil.GetValidAccount(),
expiration,
)
require.NoError(t, err)
require.NotEmpty(t, token)
claims, err := domain.ParseToken(token)
require.NoError(t, err)
require.NotNil(t, claims)
require.Equal(t, expiration.Unix(), claims.ExpiresAt.Time.Unix())
})
}
1 change: 1 addition & 0 deletions internal/http/middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func AuthMiddleware(deps *dependencies.Dependencies) gin.HandlerFunc {

account, err := deps.Domains.Auth.CheckToken(c, token)
if err != nil {
deps.Log.WithError(err).Error("Failed to check token")
return
}

Expand Down
8 changes: 4 additions & 4 deletions internal/http/middleware/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ func TestAuthMiddleware(t *testing.T) {
})

t.Run("test authorization header", func(t *testing.T) {
account := model.Account{Username: "shiori"}
token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute))
account := testutil.GetValidAccount()
token, err := deps.Domains.Auth.CreateTokenForAccount(account, time.Now().Add(time.Minute))
require.NoError(t, err)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand All @@ -74,8 +74,8 @@ func TestAuthMiddleware(t *testing.T) {
})

t.Run("test authorization cookie", func(t *testing.T) {
account := model.Account{Username: "shiori"}
token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute))
account := testutil.GetValidAccount()
token, err := deps.Domains.Auth.CreateTokenForAccount(account, time.Now().Add(time.Minute))
require.NoError(t, err)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
Expand Down
22 changes: 14 additions & 8 deletions internal/http/routes/api/v1/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (r *AuthAPIRoutes) Setup(group *gin.RouterGroup) model.Routes {
group.GET("/me", r.meHandler)
group.POST("/login", r.loginHandler)
group.POST("/refresh", r.refreshHandler)
group.PATCH("/account", r.settingsHandler)
group.PATCH("/account", r.updateHandler)
return r
}

Expand Down Expand Up @@ -81,18 +81,20 @@ func (r *AuthAPIRoutes) loginHandler(c *gin.Context) {
return
}

expiration := time.Now().Add(time.Hour)
expiration := time.Hour
if payload.RememberMe {
expiration = time.Now().Add(time.Hour * 24 * 30)
expiration = time.Hour * 24 * 30
}

token, err := r.deps.Domains.Auth.CreateTokenForAccount(account, expiration)
expirationTime := time.Now().Add(expiration)

token, err := r.deps.Domains.Auth.CreateTokenForAccount(account, expirationTime)
if err != nil {
response.SendInternalServerError(c)
return
}

sessionID, err := r.legacyLoginHandler(*account, time.Hour*24*30)
sessionID, err := r.legacyLoginHandler(*account, expiration)
if err != nil {
r.logger.WithError(err).Error("failed execute legacy login handler")
response.SendInternalServerError(c)
Expand All @@ -102,7 +104,7 @@ func (r *AuthAPIRoutes) loginHandler(c *gin.Context) {
response.Send(c, http.StatusOK, loginResponseMessage{
Token: token,
SessionID: sessionID,
Expiration: expiration.Unix(),
Expiration: expirationTime.Unix(),
})
}

Expand Down Expand Up @@ -154,7 +156,7 @@ func (r *AuthAPIRoutes) meHandler(c *gin.Context) {
response.Send(c, http.StatusOK, ctx.GetAccount())
}

// settingsHandler godoc
// updateHandler godoc
//
// @Summary Perform actions on the currently logged-in user.
// @Tags Auth
Expand All @@ -164,7 +166,7 @@ func (r *AuthAPIRoutes) meHandler(c *gin.Context) {
// @Success 200 {object} model.Account
// @Failure 403 {object} nil "Token not provided/invalid"
// @Router /api/v1/auth/account [patch]
func (r *AuthAPIRoutes) settingsHandler(c *gin.Context) {
func (r *AuthAPIRoutes) updateHandler(c *gin.Context) {
ctx := context.NewContextFromGin(c)
if !ctx.UserIsLogged() {
response.SendError(c, http.StatusForbidden, nil)
Expand All @@ -175,6 +177,10 @@ func (r *AuthAPIRoutes) settingsHandler(c *gin.Context) {
}

account := ctx.GetAccount()
if account == nil {
response.SendError(c, http.StatusUnauthorized, nil)
return
}
account.Config = payload.Config

err := r.deps.Database.SaveAccountSettings(c, *account)
Expand Down
35 changes: 18 additions & 17 deletions internal/http/routes/api/v1/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,11 @@ func TestAccountsRoute(t *testing.T) {
router.Setup(g.Group("/"))

// Create an account manually to test
account := model.Account{
Username: "shiori",
Password: "gopher",
Owner: true,
}
require.NoError(t, deps.Database.SaveAccount(ctx, account))
account := testutil.GetValidAccount()
account.Owner = true
require.NoError(t, deps.Database.SaveAccount(ctx, *account))

token, err := deps.Domains.Auth.CreateTokenForAccount(&account, time.Now().Add(time.Minute))
token, err := deps.Domains.Auth.CreateTokenForAccount(account, time.Now().Add(time.Minute))
require.NoError(t, err)

req := httptest.NewRequest("GET", "/me", nil)
Expand Down Expand Up @@ -175,9 +172,7 @@ func TestRefreshHandler(t *testing.T) {
})

t.Run("token valid", func(t *testing.T) {
token, err := deps.Domains.Auth.CreateTokenForAccount(&model.Account{
Username: "shiori",
}, time.Now().Add(time.Minute))
token, err := deps.Domains.Auth.CreateTokenForAccount(testutil.GetValidAccount(), time.Now().Add(time.Minute))
require.NoError(t, err)

w := testutil.PerformRequest(g, "POST", "/refresh", testutil.WithHeader(model.AuthorizationHeader, model.AuthorizationTokenType+" "+token))
Expand All @@ -186,7 +181,7 @@ func TestRefreshHandler(t *testing.T) {
})
}

func TestSettingsHandler(t *testing.T) {
func TestUpdateHandler(t *testing.T) {
logger := logrus.New()
ctx := context.TODO()
g := testutil.NewGin()
Expand All @@ -196,10 +191,18 @@ func TestSettingsHandler(t *testing.T) {
g.Use(middleware.AuthMiddleware(deps))
router.Setup(g.Group("/"))

require.NoError(t, deps.Database.SaveAccount(ctx, model.Account{
Username: "shiori",
Password: "gopher",
}))

t.Run("invalid token", func(t *testing.T) {
w := testutil.PerformRequest(g, "PATCH", "/account")
require.Equal(t, http.StatusForbidden, w.Code)
})

t.Run("token valid", func(t *testing.T) {
token, err := deps.Domains.Auth.CreateTokenForAccount(&model.Account{
Username: "shiori",
}, time.Now().Add(time.Minute))
token, err := deps.Domains.Auth.CreateTokenForAccount(testutil.GetValidAccount(), time.Now().Add(time.Minute))
require.NoError(t, err)

type settingRequestPayload struct {
Expand All @@ -222,9 +225,7 @@ func TestSettingsHandler(t *testing.T) {
})

t.Run("config not valid", func(t *testing.T) {
token, err := deps.Domains.Auth.CreateTokenForAccount(&model.Account{
Username: "shiori",
}, time.Now().Add(time.Minute))
token, err := deps.Domains.Auth.CreateTokenForAccount(testutil.GetValidAccount(), time.Now().Add(time.Minute))
require.NoError(t, err)

w := testutil.PerformRequest(g, "PATCH", "/account", testutil.WithBody("notValidConfig"), testutil.WithHeader(model.AuthorizationHeader, model.AuthorizationTokenType+" "+token))
Expand Down
Loading

0 comments on commit 4a58ef0

Please sign in to comment.