Skip to content

Commit

Permalink
feat: add store for access token cache
Browse files Browse the repository at this point in the history
  • Loading branch information
boojack committed Aug 26, 2024
1 parent 4b2e50c commit 5d7dc80
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 54 deletions.
8 changes: 8 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package memogram

import (
"os"
"path"

"github.com/caarlos0/env"
"github.com/joho/godotenv"
"github.com/pkg/errors"
Expand All @@ -10,6 +12,7 @@ import (
type Config struct {
ServerAddr string `env:"SERVER_ADDR,required"`
BotToken string `env:"BOT_TOKEN,required"`
Data string `env:"DATA"`
}

func getConfigFromEnv() (*Config, error) {
Expand All @@ -25,5 +28,10 @@ func getConfigFromEnv() (*Config, error) {
if err := env.Parse(&config); err != nil {
return nil, errors.Wrap(err, "invalid configuration")
}
if config.Data == "" {
// Default to `data.txt` if not specified.
config.Data = "data.txt"
}
config.Data = path.Join(".", config.Data)
return &config, nil
}
103 changes: 49 additions & 54 deletions memogram.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,23 @@ import (
"net/http"
"path/filepath"
"strings"
"sync"

"github.com/go-telegram/bot"
"github.com/go-telegram/bot/models"
"github.com/pkg/errors"
"github.com/usememos/memogram/store"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
fieldmaskpb "google.golang.org/protobuf/types/known/fieldmaskpb"
)

// userAccessTokenCache is a cache for user access token.
// Key is the user id from telegram.
// Value is the access token from memos.
// TODO: save it to a persistent storage.
var userAccessTokenCache sync.Map // map[int64]string

type Service struct {
config *Config
client *MemosClient
bot *bot.Bot
client *MemosClient
config *Config
store *store.Store
}

func NewService() (*Service, error) {
Expand All @@ -38,16 +33,21 @@ func NewService() (*Service, error) {
return nil, errors.Wrap(err, "failed to get config from env")
}

conn, err := grpc.Dial(config.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
conn, err := grpc.NewClient(config.ServerAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
slog.Error("failed to connect to server", slog.Any("err", err))
return nil, errors.Wrap(err, "failed to connect to server")
}
client := NewMemosClient(conn)

store := store.NewStore(config.Data)
if err := store.Init(); err != nil {
return nil, errors.Wrap(err, "failed to init store")
}
s := &Service{
config: config,
client: client,
store: store,
}

opts := []bot.Option{
Expand Down Expand Up @@ -86,7 +86,7 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) {
}

userID := m.Message.From.ID
if _, ok := userAccessTokenCache.Load(userID); !ok {
if _, ok := s.store.GetUserAccessToken(userID); !ok {
b.SendMessage(ctx, &bot.SendMessageParams{
ChatID: m.Message.Chat.ID,
Text: "Please start the bot with /start <access_token>",
Expand Down Expand Up @@ -147,8 +147,8 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) {
return
}

accessToken, _ := userAccessTokenCache.Load(userID)
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string))))
accessToken, _ := s.store.GetUserAccessToken(userID)
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken)))
memo, err := s.client.MemoService.CreateMemo(ctx, &v1pb.CreateMemoRequest{
Content: content,
})
Expand All @@ -164,15 +164,12 @@ func (s *Service) handler(ctx context.Context, b *bot.Bot, m *models.Update) {
if message.Document != nil {
s.processFileMessage(ctx, b, m, message.Document.FileID, memo)
}

if message.Voice != nil {
s.processFileMessage(ctx, b, m, message.Voice.FileID, memo)
}

if message.Video != nil {
s.processFileMessage(ctx, b, m, message.Video.FileID, memo)
}

if len(message.Photo) > 0 {
photo := message.Photo[len(message.Photo)-1]
s.processFileMessage(ctx, b, m, photo.FileID, memo)
Expand Down Expand Up @@ -204,7 +201,7 @@ func (s *Service) startHandler(ctx context.Context, b *bot.Bot, m *models.Update
return
}

userAccessTokenCache.Store(userID, accessToken)
s.store.SetUserAccessToken(userID, accessToken)
b.SendMessage(ctx, &bot.SendMessageParams{
ChatID: m.Message.Chat.ID,
Text: fmt.Sprintf("Hello %s!", user.Nickname),
Expand All @@ -214,29 +211,29 @@ func (s *Service) startHandler(ctx context.Context, b *bot.Bot, m *models.Update
func (s *Service) keyboard(memo *v1pb.Memo) *models.InlineKeyboardMarkup {
// add inline keyboard to edit memo's visibility or pinned status.
return &models.InlineKeyboardMarkup{
InlineKeyboard: [][]models.InlineKeyboardButton{
InlineKeyboard: [][]models.InlineKeyboardButton{
{
{
Text: "Public",
CallbackData: fmt.Sprintf("public %s", memo.Name),
},
{
{
Text: "Public",
CallbackData: fmt.Sprintf("public %s", memo.Name),
},
{
Text: "Private",
CallbackData: fmt.Sprintf("private %s", memo.Name),
},
{
Text: "Pin",
CallbackData: fmt.Sprintf("pin %s", memo.Name),
},
Text: "Private",
CallbackData: fmt.Sprintf("private %s", memo.Name),
},
{
Text: "Pin",
CallbackData: fmt.Sprintf("pin %s", memo.Name),
},
},
}
},
}
}

func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *models.Update) {
callbackData := update.CallbackQuery.Data
userID := update.CallbackQuery.From.ID
accessToken, ok := userAccessTokenCache.Load(userID)
accessToken, ok := s.store.GetUserAccessToken(userID)
if !ok {
b.AnswerCallbackQuery(ctx, &bot.AnswerCallbackQueryParams{
CallbackQueryID: update.CallbackQuery.ID,
Expand All @@ -246,7 +243,7 @@ func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *
return
}

ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string))))
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken)))

parts := strings.Split(callbackData, " ")
if len(parts) != 2 {
Expand Down Expand Up @@ -313,10 +310,10 @@ func (s *Service) callbackQueryHandler(ctx context.Context, b *bot.Bot, update *
pinnedMarker = ""
}
b.EditMessageText(ctx, &bot.EditMessageTextParams{
ChatID: update.CallbackQuery.Message.Message.Chat.ID,
MessageID: update.CallbackQuery.Message.Message.ID,
Text: fmt.Sprintf("Memo updated as %s with [%s](%s/m/%s) %s", v1pb.Visibility_name[int32(memo.Visibility)], memo.Name, s.config.ServerAddr, memo.Uid, pinnedMarker),
ParseMode: models.ParseModeMarkdown,
ChatID: update.CallbackQuery.Message.Message.Chat.ID,
MessageID: update.CallbackQuery.Message.Message.ID,
Text: fmt.Sprintf("Memo updated as %s with [%s](%s/m/%s) %s", v1pb.Visibility_name[int32(memo.Visibility)], memo.Name, s.config.ServerAddr, memo.Uid, pinnedMarker),
ParseMode: models.ParseModeMarkdown,
ReplyMarkup: s.keyboard(memo),
})

Expand All @@ -332,8 +329,8 @@ func (s *Service) searchHandler(ctx context.Context, b *bot.Bot, m *models.Updat

filterString := "content_search == ['" + searchString + "']"

accessToken, _ := userAccessTokenCache.Load(userID)
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken.(string))))
accessToken, _ := s.store.GetUserAccessToken(userID)
ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("Authorization", fmt.Sprintf("Bearer %s", accessToken)))
results, err := s.client.MemoService.ListMemos(ctx, &v1pb.ListMemosRequest{
PageSize: 10,
Filter: filterString,
Expand All @@ -360,8 +357,20 @@ func (s *Service) searchHandler(ctx context.Context, b *bot.Bot, m *models.Updat
})
}
}
}

func (s *Service) processFileMessage(ctx context.Context, b *bot.Bot, m *models.Update, fileID string, memo *v1pb.Memo) {
file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: fileID})
if err != nil {
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file"))
return
}

return
_, err = s.saveResourceFromFile(ctx, file, memo)
if err != nil {
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource"))
return
}
}

func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, memo *v1pb.Memo) (*v1pb.Resource, error) {
Expand Down Expand Up @@ -397,20 +406,6 @@ func (s *Service) saveResourceFromFile(ctx context.Context, file *models.File, m
return resource, nil
}

func (s *Service) processFileMessage(ctx context.Context, b *bot.Bot, m *models.Update, fileID string, memo *v1pb.Memo) {
file, err := b.GetFile(ctx, &bot.GetFileParams{FileID: fileID})
if err != nil {
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to get file"))
return
}

_, err = s.saveResourceFromFile(ctx, file, memo)
if err != nil {
s.sendError(b, m.Message.Chat.ID, errors.Wrap(err, "failed to save resource"))
return
}
}

func (s *Service) sendError(b *bot.Bot, chatID int64, err error) {
slog.Error("error", slog.Any("err", err))
b.SendMessage(context.Background(), &bot.SendMessageParams{
Expand Down
29 changes: 29 additions & 0 deletions store/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package store

import (
"sync"

"github.com/pkg/errors"
)

type Store struct {
Data string

userAccessTokenCache sync.Map // map[int64]string
}

func NewStore(data string) *Store {
return &Store{
Data: data,

userAccessTokenCache: sync.Map{},
}
}

func (s *Store) Init() error {
if err := s.loadUserAccessTokenMapFromFile(); err != nil {
return errors.Wrap(err, "failed to load user access token map from file")
}

return nil
}
100 changes: 100 additions & 0 deletions store/user.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package store

import (
"bufio"
"log/slog"
"os"
"strconv"
"strings"
)

// GetUserAccessToken returns the access token for the user.
func (s *Store) GetUserAccessToken(userID int64) (string, bool) {
accessToken, ok := s.userAccessTokenCache.Load(userID)
if !ok {
return "", false
}
return accessToken.(string), true
}

// SetUserAccessToken sets the access token for the user.
func (s *Store) SetUserAccessToken(userID int64, accessToken string) {
s.userAccessTokenCache.Store(userID, accessToken)
if err := s.SaveUserAccessTokenMapToFile(); err != nil {
slog.Error("failed to save user access token map to file", "error", err)
}
}

// SaveUserAccessTokenMapToFile saves the user access token map to a data file.
func (s *Store) SaveUserAccessTokenMapToFile() error {
// Open the file for writing
file, err := os.OpenFile(s.Data, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return err
}
defer file.Close()

// Iterate over the user access token map and write each entry to the file
s.userAccessTokenCache.Range(func(key, value interface{}) bool {
userID := key.(int64)
accessToken := value.(string)
line := strconv.FormatInt(userID, 10) + ":" + accessToken + "\n"
_, err := file.WriteString(line)
if err != nil {
return false
}
return true
})

return nil
}

func (s *Store) loadUserAccessTokenMapFromFile() error {
// Check if the file exists
if _, err := os.Stat(s.Data); os.IsNotExist(err) {
// Create the file if it doesn't exist
file, err := os.Create(s.Data)
if err != nil {
return err
}
defer file.Close()
}

// Open the file
file, err := os.Open(s.Data)
if err != nil {
return err
}
defer file.Close()

// Read the file line by line
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
// Parse the line and extract the user ID and access token
userID, accessToken := parseLine(line)
if userID == 0 || accessToken == "" {
continue
}
// Store the user ID and access token in the cache
s.userAccessTokenCache.Store(userID, accessToken)
}
if err := scanner.Err(); err != nil {
return err
}
return nil
}

func parseLine(line string) (int64, string) {
parts := strings.Split(line, ":")
if len(parts) != 2 {
return 0, ""
}
userIDStr := parts[0]
accessToken := parts[1]
userID, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
return 0, ""
}
return userID, accessToken
}

0 comments on commit 5d7dc80

Please sign in to comment.