From 45bf75709d52e4af73bc383a323369ccc2342f40 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Mon, 11 Sep 2023 14:44:22 +0800 Subject: [PATCH] update subscription --- api/chat.go | 16 ++++++--- api/image.go | 25 ++++++------- api/stream.go | 12 +++++-- app/src/conversation/addition.ts | 6 ++-- app/src/routes/Subscription.tsx | 28 +++++++-------- auth/controller.go | 60 +++++++++++++++++++++++++++++--- auth/subscription.go | 49 ++++++++++++++++++++++++++ auth/usage.go | 9 +++++ auth/user.go | 31 +++++++++++++++++ connection/database.go | 18 ++++++++++ main.go | 2 ++ middleware/throttle.go | 2 ++ utils/char.go | 4 +++ 13 files changed, 218 insertions(+), 44 deletions(-) create mode 100644 auth/subscription.go diff --git a/api/chat.go b/api/chat.go index 00266210..3c95365d 100644 --- a/api/chat.go +++ b/api/chat.go @@ -29,7 +29,7 @@ func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentRespon _ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message))) } -func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string { +func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string { var keyword string var segment []types.ChatGPTMessage @@ -41,7 +41,8 @@ func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conve SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false}) - if instance.IsEnableGPT4() && !auth.CanEnableGPT4(db, user) { + isProPlan := auth.CanEnableSubscription(db, cache, user) + if instance.IsEnableGPT4() && (!isProPlan) && (!auth.CanEnableGPT4(db, user)) { SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ Message: defaultQuotaMessage, Quota: 0, @@ -51,7 +52,7 @@ func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conve } buffer := NewBuffer(instance.IsEnableGPT4(), segment) - StreamRequest(instance.IsEnableGPT4(), segment, 2000, func(resp string) { + StreamRequest(instance.IsEnableGPT4(), isProPlan, segment, 2000, func(resp string) { SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ Message: buffer.Write(resp), Quota: buffer.GetQuota(), @@ -59,6 +60,9 @@ func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conve }) }) if buffer.IsEmpty() { + if isProPlan { + auth.DecreaseSubscriptionUsage(cache, user) + } SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ Message: defaultErrorMessage, Quota: -0xe, // special value for error @@ -68,7 +72,9 @@ func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conve } // collect quota - user.UseQuota(db, buffer.GetQuota()) + if !isProPlan { + user.UseQuota(db, buffer.GetQuota()) + } SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true, Quota: buffer.GetQuota()}) return buffer.ReadWithDefault(defaultErrorMessage) @@ -111,7 +117,7 @@ func ChatHandler(conn *websocket.Conn, instance *conversation.Conversation, user if strings.HasPrefix(instance.GetLatestMessage(), "/image") { return ImageChat(conn, instance, user, db, cache) } else { - return TextChat(db, user, conn, instance) + return TextChat(db, cache, user, conn, instance) } } diff --git a/api/image.go b/api/image.go index 5bfaa61e..21c1d4e0 100644 --- a/api/image.go +++ b/api/image.go @@ -48,26 +48,21 @@ func GetImageWithCache(ctx context.Context, prompt string, cache *redis.Client) } func GetLimitFormat(id int64) string { - t := time.Now().Format("2006-01-02") - return fmt.Sprintf(":imagelimit:%s:%d", t, id) + today := time.Now().Format("2006-01-02") + return fmt.Sprintf(":imagelimit:%s:%d", today, id) } func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *redis.Client) (string, error) { - // 5 images one day per user (count by cache) - res, err := cache.Get(context.Background(), GetLimitFormat(user.GetID(db))).Result() - if err != nil || len(res) == 0 || res == "" { - cache.Set(context.Background(), GetLimitFormat(user.GetID(db)), "1", time.Hour*24) - return GetImageWithCache(context.Background(), prompt, cache) - } + // free plan: 5 images per day + // pro plan: 50 images per day - if res == "5" { - if auth.ReduceDalle(db, user) { - return GetImageWithCache(context.Background(), prompt, cache) - } - return "", fmt.Errorf("you have reached your limit of 5 free images per day, please buy more dalle usage or wait until tomorrow") - } else { - cache.Set(context.Background(), GetLimitFormat(user.GetID(db)), fmt.Sprintf("%d", utils.ToInt(res)+1), time.Hour*24) + key := GetLimitFormat(user.GetID(db)) + usage := auth.GetDalleUsageLimit(db, user) + + if utils.IncrWithLimit(cache, key, 1, int64(usage), 60*60*24) || auth.ReduceDalle(db, user) { return GetImageWithCache(context.Background(), prompt, cache) + } else { + return "", fmt.Errorf("you have reached your limit of %d free images per day, please buy more quota or wait until tomorrow", usage) } } diff --git a/api/stream.go b/api/stream.go index bab6ea05..8c93897b 100644 --- a/api/stream.go +++ b/api/stream.go @@ -56,7 +56,7 @@ func NativeStreamRequest(model string, endpoint string, apikeys string, messages client := &http.Client{} req, err := http.NewRequest("POST", endpoint+"/chat/completions", utils.ConvertBody(types.ChatGPTRequest{ - Model: model, + Model: strings.Replace(model, "reverse", "free", -1), Messages: messages, MaxToken: token, Stream: true, @@ -93,9 +93,15 @@ func NativeStreamRequest(model string, endpoint string, apikeys string, messages } } -func StreamRequest(enableGPT4 bool, messages []types.ChatGPTMessage, token int, callback func(string)) { +func StreamRequest(enableGPT4 bool, isProPlan bool, messages []types.ChatGPTMessage, token int, callback func(string)) { + var model string + if isProPlan { + model = "gpt-4-reverse" // using reverse engine + } else { + model = "gpt-4" + } if enableGPT4 { - NativeStreamRequest("gpt-4", viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback) + NativeStreamRequest(model, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback) } else { NativeStreamRequest("gpt-3.5-turbo-16k-0613", viper.GetString("openai.user_endpoint"), viper.GetString("openai.user"), messages, token, callback) } diff --git a/app/src/conversation/addition.ts b/app/src/conversation/addition.ts index 5bda6933..971cd271 100644 --- a/app/src/conversation/addition.ts +++ b/app/src/conversation/addition.ts @@ -57,8 +57,8 @@ export async function getSubscription(): Promise { } return { status: resp.data.status, - is_subscribed: resp.data.data.is_subscribed, - expired: resp.data.data.expired, + is_subscribed: resp.data.is_subscribed, + expired: resp.data.expired, }; } catch (e) { console.debug(e); @@ -70,7 +70,7 @@ export async function buySubscription( month: number, ): Promise { try { - const resp = await axios.post(`/subscription`, { month }); + const resp = await axios.post(`/subscribe`, { month }); return resp.data as BuySubscriptionResponse; } catch (e) { console.debug(e); diff --git a/app/src/routes/Subscription.tsx b/app/src/routes/Subscription.tsx index 4acc1dfe..0813c090 100644 --- a/app/src/routes/Subscription.tsx +++ b/app/src/routes/Subscription.tsx @@ -76,16 +76,16 @@ function Upgrade({ children }: UpgradeProps) {