Skip to content

Commit

Permalink
update subscription
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Sep 11, 2023
1 parent b7b9c3b commit 45bf757
Show file tree
Hide file tree
Showing 13 changed files with 218 additions and 44 deletions.
16 changes: 11 additions & 5 deletions api/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentRespon
_ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message)))
}

func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
var keyword string
var segment []types.ChatGPTMessage

Expand All @@ -41,7 +41,8 @@ func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conve

SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})

if instance.IsEnableGPT4() && !auth.CanEnableGPT4(db, user) {
isProPlan := auth.CanEnableSubscription(db, cache, user)
if instance.IsEnableGPT4() && (!isProPlan) && (!auth.CanEnableGPT4(db, user)) {
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: defaultQuotaMessage,
Quota: 0,
Expand All @@ -51,14 +52,17 @@ func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conve
}

buffer := NewBuffer(instance.IsEnableGPT4(), segment)
StreamRequest(instance.IsEnableGPT4(), segment, 2000, func(resp string) {
StreamRequest(instance.IsEnableGPT4(), isProPlan, segment, 2000, func(resp string) {
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: buffer.Write(resp),
Quota: buffer.GetQuota(),
End: false,
})
})
if buffer.IsEmpty() {
if isProPlan {
auth.DecreaseSubscriptionUsage(cache, user)
}
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: defaultErrorMessage,
Quota: -0xe, // special value for error
Expand All @@ -68,7 +72,9 @@ func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conve
}

// collect quota
user.UseQuota(db, buffer.GetQuota())
if !isProPlan {
user.UseQuota(db, buffer.GetQuota())
}
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true, Quota: buffer.GetQuota()})

return buffer.ReadWithDefault(defaultErrorMessage)
Expand Down Expand Up @@ -111,7 +117,7 @@ func ChatHandler(conn *websocket.Conn, instance *conversation.Conversation, user
if strings.HasPrefix(instance.GetLatestMessage(), "/image") {
return ImageChat(conn, instance, user, db, cache)
} else {
return TextChat(db, user, conn, instance)
return TextChat(db, cache, user, conn, instance)
}
}

Expand Down
25 changes: 10 additions & 15 deletions api/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,21 @@ func GetImageWithCache(ctx context.Context, prompt string, cache *redis.Client)
}

func GetLimitFormat(id int64) string {
t := time.Now().Format("2006-01-02")
return fmt.Sprintf(":imagelimit:%s:%d", t, id)
today := time.Now().Format("2006-01-02")
return fmt.Sprintf(":imagelimit:%s:%d", today, id)
}

func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *redis.Client) (string, error) {
// 5 images one day per user (count by cache)
res, err := cache.Get(context.Background(), GetLimitFormat(user.GetID(db))).Result()
if err != nil || len(res) == 0 || res == "" {
cache.Set(context.Background(), GetLimitFormat(user.GetID(db)), "1", time.Hour*24)
return GetImageWithCache(context.Background(), prompt, cache)
}
// free plan: 5 images per day
// pro plan: 50 images per day

if res == "5" {
if auth.ReduceDalle(db, user) {
return GetImageWithCache(context.Background(), prompt, cache)
}
return "", fmt.Errorf("you have reached your limit of 5 free images per day, please buy more dalle usage or wait until tomorrow")
} else {
cache.Set(context.Background(), GetLimitFormat(user.GetID(db)), fmt.Sprintf("%d", utils.ToInt(res)+1), time.Hour*24)
key := GetLimitFormat(user.GetID(db))
usage := auth.GetDalleUsageLimit(db, user)

if utils.IncrWithLimit(cache, key, 1, int64(usage), 60*60*24) || auth.ReduceDalle(db, user) {
return GetImageWithCache(context.Background(), prompt, cache)
} else {
return "", fmt.Errorf("you have reached your limit of %d free images per day, please buy more quota or wait until tomorrow", usage)
}
}

Expand Down
12 changes: 9 additions & 3 deletions api/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func NativeStreamRequest(model string, endpoint string, apikeys string, messages

client := &http.Client{}
req, err := http.NewRequest("POST", endpoint+"/chat/completions", utils.ConvertBody(types.ChatGPTRequest{
Model: model,
Model: strings.Replace(model, "reverse", "free", -1),
Messages: messages,
MaxToken: token,
Stream: true,
Expand Down Expand Up @@ -93,9 +93,15 @@ func NativeStreamRequest(model string, endpoint string, apikeys string, messages
}
}

func StreamRequest(enableGPT4 bool, messages []types.ChatGPTMessage, token int, callback func(string)) {
func StreamRequest(enableGPT4 bool, isProPlan bool, messages []types.ChatGPTMessage, token int, callback func(string)) {
var model string
if isProPlan {
model = "gpt-4-reverse" // using reverse engine
} else {
model = "gpt-4"
}
if enableGPT4 {
NativeStreamRequest("gpt-4", viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
NativeStreamRequest(model, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
} else {
NativeStreamRequest("gpt-3.5-turbo-16k-0613", viper.GetString("openai.user_endpoint"), viper.GetString("openai.user"), messages, token, callback)
}
Expand Down
6 changes: 3 additions & 3 deletions app/src/conversation/addition.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ export async function getSubscription(): Promise<SubscriptionResponse> {
}
return {
status: resp.data.status,
is_subscribed: resp.data.data.is_subscribed,
expired: resp.data.data.expired,
is_subscribed: resp.data.is_subscribed,
expired: resp.data.expired,
};
} catch (e) {
console.debug(e);
Expand All @@ -70,7 +70,7 @@ export async function buySubscription(
month: number,
): Promise<BuySubscriptionResponse> {
try {
const resp = await axios.post(`/subscription`, { month });
const resp = await axios.post(`/subscribe`, { month });
return resp.data as BuySubscriptionResponse;
} catch (e) {
console.debug(e);
Expand Down
28 changes: 14 additions & 14 deletions app/src/routes/Subscription.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ function Upgrade({ children }: UpgradeProps) {
</DialogHeader>
<div className="upgrade-wrapper">
<Select onValueChange={
(value: number) => setMonth(value)
(value: string) => setMonth(parseInt(value))
}>
<SelectTrigger className="w-[200px]">
<SelectValue placeholder={t(`sub.time.${month}`)} />
</SelectTrigger>
<SelectContent>
<SelectItem value={1}>{t(`sub.time.1`)}</SelectItem>
<SelectItem value={3}>{t(`sub.time.3`)}</SelectItem>
<SelectItem value={6}>{t(`sub.time.6`)}</SelectItem>
<SelectItem value={12}>
<SelectItem value={"1"}>{t(`sub.time.1`)}</SelectItem>
<SelectItem value={"3"}>{t(`sub.time.3`)}</SelectItem>
<SelectItem value={"6"}>{t(`sub.time.6`)}</SelectItem>
<SelectItem value={"12"}>
{t(`sub.time.12`)}
<Badge className={`ml-2 cent`}>{t(`percent`, { cent: 9 })}</Badge>
</SelectItem>
Expand All @@ -97,16 +97,16 @@ function Upgrade({ children }: UpgradeProps) {
<Button variant={`outline`} onClick={
() => setOpen(false)
}>{ t('cancel') }</Button>
<Button>
<Plus className={`h-4 w-4 mr-1`} onClick={
async () => {
const res = await callBuyAction(t, toast, month);
if (res) {
setOpen(false);
await refreshSubscription(dispatch);
}
<Button onClick={
async () => {
const res = await callBuyAction(t, toast, month);
if (res) {
setOpen(false);
await refreshSubscription(dispatch);
}
} />
}
}>
<Plus className={`h-4 w-4 mr-1`} />
{ t('confirm') }
</Button>
</DialogFooter>
Expand Down
60 changes: 56 additions & 4 deletions auth/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ type BuyForm struct {
Quota int `json:"quota" binding:"required"`
}

type SubscribeForm struct {
Month int `json:"month" binding:"required"`
}

func GetUserByCtx(c *gin.Context) *User {
user := c.MustGet("user").(string)
if len(user) == 0 {
Expand Down Expand Up @@ -50,6 +54,57 @@ func QuotaAPI(c *gin.Context) {
})
}

func SubscriptionAPI(c *gin.Context) {
user := GetUserByCtx(c)
if user == nil {
return
}

db := utils.GetDBFromContext(c)
c.JSON(200, gin.H{
"status": true,
"is_subscribed": user.IsSubscribe(db),
"expired": user.GetSubscriptionExpiredDay(db),
})
}

func SubscribeAPI(c *gin.Context) {
user := GetUserByCtx(c)
if user == nil {
return
}

db := utils.GetDBFromContext(c)
var form SubscribeForm
if err := c.ShouldBindJSON(&form); err != nil {
c.JSON(200, gin.H{
"status": false,
"error": err.Error(),
})
return
}

if form.Month <= 0 || form.Month > 999 {
c.JSON(200, gin.H{
"status": false,
"error": "invalid month range (1 ~ 999)",
})
return
}

if BuySubscription(db, user, form.Month) {
c.JSON(200, gin.H{
"status": true,
"error": "success",
})
} else {
c.JSON(200, gin.H{
"status": false,
"error": "not enough money",
})
}
}

func BuyAPI(c *gin.Context) {
user := GetUserByCtx(c)
if user == nil {
Expand All @@ -74,10 +129,7 @@ func BuyAPI(c *gin.Context) {
return
}

money := float32(form.Quota) * 0.1
if Pay(user.Username, money) {
user.IncreaseQuota(db, float32(form.Quota))

if BuyQuota(db, user, form.Quota) {
c.JSON(200, gin.H{
"status": true,
"error": "success",
Expand Down
49 changes: 49 additions & 0 deletions auth/subscription.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package auth

import (
"chat/utils"
"database/sql"
"fmt"
"github.com/go-redis/redis/v8"
"time"
)

func CountSubscriptionPrize(month int) float32 {
if month >= 12 {
return 8 * float32(month) * 0.9
}
return 8 * float32(month)
}

func BuySubscription(db *sql.DB, user *User, month int) bool {
if month < 1 || month > 999 {
return false
}
money := CountSubscriptionPrize(month)
if Pay(user.Username, money) {
user.AddSubscription(db, month)
return true
}
return false
}

func IncreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
today := time.Now().Format("2006-01-02")
return utils.IncrWithLimit(cache, fmt.Sprintf(":subscription-usage:%s:%d", today, user.ID), 1, 999, 60*60*24) // 1 day
}

func DecreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
today := time.Now().Format("2006-01-02")
return utils.DecrInt(cache, fmt.Sprintf(":subscription-usage:%s:%d", today, user.ID), 1)
}

func CanEnableSubscription(db *sql.DB, cache *redis.Client, user *User) bool {
return user.IsSubscribe(db) && IncreaseSubscriptionUsage(cache, user)
}

func GetDalleUsageLimit(db *sql.DB, user *User) int {
if user.IsSubscribe(db) {
return 50
}
return 5
}
9 changes: 9 additions & 0 deletions auth/usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,12 @@ func ReduceDalle(db *sql.DB, user *User) bool {
func CanEnableGPT4(db *sql.DB, user *User) bool {
return user.GetQuota(db) >= 5
}

func BuyQuota(db *sql.DB, user *User, quota int) bool {
money := float32(quota) * 0.1
if Pay(user.Username, money) {
user.IncreaseQuota(db, float32(quota))
return true
}
return false
}
31 changes: 31 additions & 0 deletions auth/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/spf13/viper"
"math"
"net/http"
"time"
)
Expand Down Expand Up @@ -128,6 +129,36 @@ func (u *User) UseQuota(db *sql.DB, quota float32) bool {
return u.IncreaseUsedQuota(db, quota)
}

func (u *User) GetSubscription(db *sql.DB) time.Time {
var expiredAt []uint8
if err := db.QueryRow("SELECT expired_at FROM subscription WHERE user_id = ?", u.GetID(db)).Scan(&expiredAt); err != nil {
return time.Unix(0, 0)
}
return *utils.ConvertTime(expiredAt)
}

func (u *User) IsSubscribe(db *sql.DB) bool {
return u.GetSubscription(db).Unix() > time.Now().Unix()
}

func (u *User) GetSubscriptionExpiredDay(db *sql.DB) int {
stamp := u.GetSubscription(db).Sub(time.Now())
return int(math.Round(stamp.Hours() / 24))
}

func (u *User) AddSubscription(db *sql.DB, month int) bool {
current := u.GetSubscription(db)
if current.Unix() < time.Now().Unix() {
current = time.Now()
}
expiredAt := current.AddDate(0, month, 0)
_, err := db.Exec(`
INSERT INTO subscription (user_id, expired_at, total_month) VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE expired_at = ?, total_month = total_month + ?
`, u.GetID(db), utils.ConvertSqlTime(expiredAt), month, utils.ConvertSqlTime(expiredAt), month)
return err == nil
}

func IsUserExist(db *sql.DB, username string) bool {
var count int
if err := db.QueryRow("SELECT COUNT(*) FROM auth WHERE username = ?", username).Scan(&count); err != nil {
Expand Down
Loading

0 comments on commit 45bf757

Please sign in to comment.