Skip to content

Commit

Permalink
🚜 refactor: Review fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
taiseidev committed Nov 26, 2024
1 parent 4ed9eb2 commit b338ec4
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 32 deletions.
38 changes: 10 additions & 28 deletions server/internal/auth/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ import (
"context"
"errors"
"fmt"
"time"

"gorm.io/gorm"
"gorm.io/gorm/clause"
)

type AuthRepository struct {
Expand All @@ -18,7 +18,7 @@ type AuthRepository struct {
// NOTE(onishi): multiple interfaces in the future
type IAuthRepository interface {
SaveOrUpdateRefreshToken(ctx context.Context, tx *gorm.DB, model *authModel.RefreshToken) error
DeleteRefreshToken(ctx context.Context, userID uint) error
DeleteRefreshToken(ctx context.Context, x *gorm.DB, userID uint) error
SaveUser(ctx context.Context, tx *gorm.DB, user *userModel.User) error
GetUserByEmail(ctx context.Context, tx *gorm.DB, email string) (*userModel.User, error)
BeginTransaction(ctx context.Context) *gorm.DB
Expand All @@ -34,33 +34,15 @@ func NewAuthRepository(db *gorm.DB) IAuthRepository {
}

func (r *AuthRepository) SaveOrUpdateRefreshToken(ctx context.Context, tx *gorm.DB, model *authModel.RefreshToken) error {
var refreshToken authModel.RefreshToken

result := tx.WithContext(ctx).Where("user_id = ?", model.UserID).First(&refreshToken)
if result.Error != nil && result.Error != gorm.ErrRecordNotFound {
return result.Error
}

if result.RowsAffected > 0 {
// Update existing refresh token
refreshToken.TokenHash = model.TokenHash
refreshToken.ExpiresAt = model.ExpiresAt
refreshToken.UpdatedAt = time.Now()
// 更新処理
if err := tx.WithContext(ctx).Save(&refreshToken).Error; err != nil {
return err
}
} else {
// Create new refresh token
if err := tx.WithContext(ctx).Create(model).Error; err != nil {
return err
}
}

return nil
return tx.WithContext(ctx).
Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "user_id"}},
DoUpdates: clause.AssignmentColumns([]string{"token_hash", "expires_at", "updated_at"}),
}).
Create(model).Error
}

func (r *AuthRepository) DeleteRefreshToken(ctx context.Context, userID uint) error {
func (r *AuthRepository) DeleteRefreshToken(ctx context.Context, x *gorm.DB, userID uint) error {
// userID に基γ₯いてγƒͺフレッシγƒ₯γƒˆγƒΌγ‚―γƒ³γ‚’ε‰Šι™€
result := r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&authModel.RefreshToken{})

Expand Down Expand Up @@ -105,5 +87,5 @@ func (r *AuthRepository) GetUserByEmail(ctx context.Context, tx *gorm.DB, email

// γƒˆγƒ©γƒ³γ‚Άγ‚―γ‚·γƒ§γƒ³ι–‹ε§‹
func (r *AuthRepository) BeginTransaction(ctx context.Context) *gorm.DB {
return r.db.Begin()
return r.db.WithContext(ctx).Begin()
}
29 changes: 26 additions & 3 deletions server/internal/auth/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"camly-api/internal/user/model"
"context"
"errors"
"fmt"

"golang.org/x/crypto/bcrypt"
)
Expand All @@ -24,7 +25,12 @@ func NewAuthService(authRepo authRepository.IAuthRepository) *AuthService {

func (s *AuthService) SignUp(ctx context.Context, user model.User) (util.TokenResponse, error) {
tx := s.authRepo.BeginTransaction(ctx)
defer tx.Rollback()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r) // re-panic after rollback
}
}()
// パスワードをハッシγƒ₯εŒ–
hash, err := bcrypt.GenerateFromPassword([]byte(user.Password), 10)
if err != nil {
Expand All @@ -51,7 +57,10 @@ func (s *AuthService) SignUp(ctx context.Context, user model.User) (util.TokenRe
return util.TokenResponse{}, err
}

tx.Commit()
if err := tx.Commit(); err != nil {
tx.Rollback()
return util.TokenResponse{}, fmt.Errorf("failed to commit transaction")
}

return tokens, nil
}
Expand Down Expand Up @@ -102,9 +111,23 @@ func (s *AuthService) Logout(ctx context.Context, accessToken string) error {
return errors.New("invalid or expired access token")
}

if err := s.authRepo.DeleteRefreshToken(ctx, userID); err != nil {
tx := s.authRepo.BeginTransaction(ctx)
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
}()

if err := s.authRepo.DeleteRefreshToken(ctx, tx, userID); err != nil {
tx.Rollback()
return err
}

if err := tx.Commit(); err != nil {
tx.Rollback()
return fmt.Errorf("failed to commit transaction")
}

return nil
}
Loading

0 comments on commit b338ec4

Please sign in to comment.