Skip to content

Commit

Permalink
MartialBE-main (#159)
Browse files Browse the repository at this point in the history
  • Loading branch information
woodchen-ink authored Oct 7, 2024
2 parents c14108b + 7e3dd26 commit db3b276
Show file tree
Hide file tree
Showing 14 changed files with 292 additions and 109 deletions.
47 changes: 44 additions & 3 deletions common/cache/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,16 @@ import (
"github.com/eko/gocache/lib/v4/store"
freecache_store "github.com/eko/gocache/store/freecache/v4"
redis_store "github.com/eko/gocache/store/redis/v4"
"golang.org/x/sync/singleflight"
)

var kvCache *marshaler.Marshaler
var ctx = context.Background()
var (
kvCache *marshaler.Marshaler
ctx = context.Background()
sfGroup singleflight.Group
CacheTimeout = 500 * time.Millisecond
CacheNotFound = errors.New("cache not found")
)

func InitCacheManager() {
var client *cacheM.Cache[any]
Expand All @@ -36,7 +42,7 @@ func GetCache[T any](key string) (T, error) {
_, err := kvCache.Get(ctx, key, &val)
if err != nil {
if errors.Is(err, store.NotFound{}) {
return val, nil
return *new(T), CacheNotFound
}
return *new(T), err
}
Expand All @@ -50,3 +56,38 @@ func SetCache(key string, value any, expiration time.Duration) error {
func DeleteCache(key string) error {
return kvCache.Delete(ctx, key)
}

func GetOrSetCache[T any](key string, expiration time.Duration, fn func() (T, error), timeout time.Duration) (T, error) {
v, err := GetCache[T](key)
if err == nil {
return v, nil
}

if !errors.Is(err, CacheNotFound) {
return *new(T), err
}

result := sfGroup.DoChan(key, func() (interface{}, error) {
v, err := fn()
if err != nil {
return nil, err
}

SetCache(key, v, expiration)

return v, nil
})

t := time.After(timeout)

select {
case r := <-result:
v, ok := r.Val.(T)
if !ok {
return *new(T), errors.New("类型断言失败")
}
return v, r.Err
case <-t:
return *new(T), errors.New("超时")
}
}
112 changes: 58 additions & 54 deletions model/cache.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package model

import (
"encoding/json"
"fmt"
"one-api/common"
"one-api/common/cache"
"one-api/common/config"
"one-api/common/logger"
"one-api/common/redis"
Expand All @@ -16,50 +15,38 @@ var (
)

func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token
if !config.RedisEnabled {
err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
}
tokenObjectString, err := redis.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil {
err := DB.Where(keyCol+" = ?", key).First(&token).Error
if err != nil {
return nil, err
}
jsonBytes, err := json.Marshal(token)
if err != nil {
return nil, err
}
err = redis.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
logger.SysError("Redis set token error: " + err.Error())
}
return &token, nil
return GetTokenByKey(key)
}
err = json.Unmarshal([]byte(tokenObjectString), &token)
return &token, err

token, err := cache.GetOrSetCache(
fmt.Sprintf("token:%s", key),
time.Duration(TokenCacheSeconds)*time.Second,
func() (*Token, error) {
return GetTokenByKey(key)
},
cache.CacheTimeout)

return token, err
}

func CacheGetUserGroup(id int) (group string, err error) {
if !config.RedisEnabled {
return GetUserGroup(id)
}
group, err = redis.RedisGet(fmt.Sprintf("user_group:%d", id))
if err != nil {
group, err = GetUserGroup(id)
if err != nil {
return "", err
}
err = redis.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
logger.SysError("Redis set user group error: " + err.Error())
}
}

group, err = cache.GetOrSetCache(
fmt.Sprintf("user_group:%d", id),
time.Duration(TokenCacheSeconds)*time.Second,
func() (string, error) {
groupId, err := GetUserGroup(id)
if err != nil {
return "", err
}
return groupId, nil
},
cache.CacheTimeout)

return group, err
}

Expand Down Expand Up @@ -107,22 +94,39 @@ func CacheIsUserEnabled(userId int) (bool, error) {
if !config.RedisEnabled {
return IsUserEnabled(userId)
}
enabled, err := redis.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}

userEnabled, err := IsUserEnabled(userId)
if err != nil {
return false, err
}
enabled = "0"
if userEnabled {
enabled = "1"
}
err = redis.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(TokenCacheSeconds)*time.Second)
if err != nil {
logger.SysError("Redis set user enabled error: " + err.Error())
enabled, err := cache.GetOrSetCache(
fmt.Sprintf("user_enabled:%d", userId),
time.Duration(TokenCacheSeconds)*time.Second,
func() (bool, error) {
enabled, err := IsUserEnabled(userId)
if err != nil {
return false, err
}
return enabled, nil
},
cache.CacheTimeout)

return enabled, err
}

func CacheGetUsername(id int) (username string, err error) {
if !config.RedisEnabled {
return GetUsernameById(id), nil
}
return userEnabled, err

username, err = cache.GetOrSetCache(
fmt.Sprintf("user_name:%d", id),
time.Duration(TokenCacheSeconds)*time.Second,
func() (string, error) {
username := GetUsernameById(id)
if username == "" {
return "", fmt.Errorf("user %d not found", id)
}

return username, nil
},
cache.CacheTimeout)

return username, err
}
33 changes: 30 additions & 3 deletions model/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"one-api/common/logger"
"one-api/common/utils"

"gorm.io/datatypes"
"gorm.io/gorm"
)

Expand All @@ -24,6 +25,9 @@ type Log struct {
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
ChannelId int `json:"channel_id" gorm:"index"`
RequestTime int `json:"request_time" gorm:"default:0"`
IsStream bool `json:"is_stream" gorm:"default:false"`

Metadata datatypes.JSONType[map[string]any] `json:"metadata" gorm:"type:json"`

Channel *Channel `json:"channel" gorm:"foreignKey:Id;references:ChannelId"`
}
Expand All @@ -40,9 +44,11 @@ func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !config.LogConsumeEnabled {
return
}
username, _ := CacheGetUsername(userId)

log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
Username: username,
CreatedAt: utils.GetTimestamp(),
Type: logType,
Content: content,
Expand All @@ -53,14 +59,29 @@ func RecordLog(userId int, logType int, content string) {
}
}

func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string, requestTime int) {
func RecordConsumeLog(
ctx context.Context,
userId int,
channelId int,
promptTokens int,
completionTokens int,
modelName string,
tokenName string,
quota int,
content string,
requestTime int,
isStream bool,
metadata map[string]any) {
logger.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {
return
}

username, _ := CacheGetUsername(userId)

log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
Username: username,
CreatedAt: utils.GetTimestamp(),
Type: LogTypeConsume,
Content: content,
Expand All @@ -71,7 +92,13 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
Quota: quota,
ChannelId: channelId,
RequestTime: requestTime,
IsStream: isStream,
}

if metadata != nil {
log.Metadata = datatypes.NewJSONType(metadata)
}

err := DB.Create(log).Error
if err != nil {
logger.LogError(ctx, "failed to record log: "+err.Error())
Expand Down
13 changes: 13 additions & 0 deletions model/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package model
import (
"errors"
"fmt"
"one-api/common"
"one-api/common/config"
"one-api/common/logger"
"one-api/common/redis"
Expand Down Expand Up @@ -122,6 +123,18 @@ func GetTokenByName(name string, userId int) (*Token, error) {
return &token, err
}

func GetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}

var token Token

err := DB.Where(keyCol+" = ?", key).First(&token).Error
return &token, err
}

func (token *Token) Insert() error {
if token.ChatCache && !config.ChatCacheEnabled {
token.ChatCache = false
Expand Down
5 changes: 3 additions & 2 deletions providers/azure/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ func (f AzureProviderFactory) Create(channel *model.Channel) base.ProviderInterf
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, openai.RequestErrorHandle),
},
IsAzure: true,
BalanceAction: false,
IsAzure: true,
BalanceAction: false,
SupportStreamOptions: true,
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion relay/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func RelayClaudeHandler(c *gin.Context, promptTokens int, chatProvider claude.Cl
return
}

quota.Consume(c, usage)
quota.Consume(c, usage, request.Stream)
if usage.CompletionTokens > 0 {
go cache.StoreCache(c.GetInt("channel_id"), usage.PromptTokens, usage.CompletionTokens, originalModel)
}
Expand Down
2 changes: 1 addition & 1 deletion relay/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func RelayGeminiHandler(c *gin.Context, promptTokens int, chatProvider gemini.Ge
return
}

quota.Consume(c, usage)
quota.Consume(c, usage, request.Stream)
if usage.CompletionTokens > 0 {
go cache.StoreCache(c.GetInt("channel_id"), usage.PromptTokens, usage.CompletionTokens, originalModel)
}
Expand Down
4 changes: 2 additions & 2 deletions relay/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCod
return
}

quota.Consume(relay.getContext(), usage)
quota.Consume(relay.getContext(), usage, relay.IsStream())
if usage.CompletionTokens > 0 {
cacheProps := relay.GetChatCache()
go cacheProps.StoreCache(relay.getContext().GetInt("channel_id"), usage.PromptTokens, usage.CompletionTokens, relay.getModelName())
Expand All @@ -138,5 +138,5 @@ func cacheProcessing(c *gin.Context, cacheProps *relay_util.ChatCacheProps, isSt
}
}

model.RecordConsumeLog(c.Request.Context(), cacheProps.UserId, cacheProps.ChannelID, cacheProps.PromptTokens, cacheProps.CompletionTokens, cacheProps.ModelName, tokenName, 0, "缓存", requestTime)
model.RecordConsumeLog(c.Request.Context(), cacheProps.UserId, cacheProps.ChannelID, cacheProps.PromptTokens, cacheProps.CompletionTokens, cacheProps.ModelName, tokenName, 0, "缓存", requestTime, isStream, nil)
}
4 changes: 2 additions & 2 deletions relay/midjourney/relay-mj.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {

defer func(ctx context.Context) {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1})
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1}, false)
} else {
quotaInstance.Undo(c)
}
Expand Down Expand Up @@ -429,7 +429,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyRe

defer func(ctx context.Context) {
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1})
quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1}, false)
} else {
quotaInstance.Undo(c)
}
Expand Down
2 changes: 1 addition & 1 deletion relay/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,6 @@ func RelayOnly(c *gin.Context) {
requestTime = int(time.Since(requestStartTime).Milliseconds())
}
}
model.RecordConsumeLog(c.Request.Context(), c.GetInt("id"), c.GetInt("channel_id"), 0, 0, "", c.GetString("token_name"), 0, "中继:"+path, requestTime)
model.RecordConsumeLog(c.Request.Context(), c.GetInt("id"), c.GetInt("channel_id"), 0, 0, "", c.GetString("token_name"), 0, "中继:"+path, requestTime, false, nil)

}
Loading

0 comments on commit db3b276

Please sign in to comment.