Skip to content

Commit

Permalink
feat: add authorization for MidJourney function calls
Browse files Browse the repository at this point in the history
  • Loading branch information
yangjian102621 committed Aug 16, 2023
1 parent c8998ba commit fab4309
Show file tree
Hide file tree
Showing 12 changed files with 79 additions and 56 deletions.
1 change: 1 addition & 0 deletions api/core/types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ type SystemConfig struct {
AdminTitle string `json:"admin_title"`
Models []string `json:"models"`
UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用
InitImgCalls int `json:"init_img_calls"`
EnabledRegister bool `json:"enabled_register"`
EnabledMsgService bool `json:"enabled_msg_service"`
}
2 changes: 2 additions & 0 deletions api/handler/admin/user_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ func (h *UserHandler) Save(c *gin.Context) {
Mobile string `json:"mobile"`
Nickname string `json:"nickname"`
Calls int `json:"calls"`
ImgCalls int `json:"img_calls"`
ChatRoles []string `json:"chat_roles"`
ExpiredTime string `json:"expired_time"`
Status bool `json:"status"`
Expand All @@ -91,6 +92,7 @@ func (h *UserHandler) Save(c *gin.Context) {
"nickname": data.Nickname,
"mobile": data.Mobile,
"calls": data.Calls,
"img_calls": data.ImgCalls,
"status": data.Status,
"chat_roles_json": utils.JsonEncode(data.ChatRoles),
"expired_time": utils.Str2stamp(data.ExpiredTime),
Expand Down
80 changes: 43 additions & 37 deletions api/handler/chat_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,55 +326,61 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} // end for

if functionCall { // 调用函数完成任务
logger.Info("函数名称:", functionName)
var params map[string]interface{}
_ = utils.JsonDecode(strings.Join(arguments, ""), &params)
logger.Info("函数参数:", params)
f := h.App.Functions[functionName]
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)

// for creating image, check if the user's img_calls > 0
if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
utils.ReplyMessage(ws, "![](/images/wx.png)")
} else {
content := data
if functionName == types.FuncMidJourney {
key := utils.Sha256(data)
//logger.Info(data, ",", key)
// add task for MidJourney
h.App.MjTaskClients.Put(key, ws)
task := types.MjTask{
UserId: userVo.Id,
RoleId: role.Id,
Icon: "/images/avatar/mid_journey.png",
ChatId: session.ChatId,
}
err := h.leveldb.Put(types.TaskStorePrefix+key, task)
if err != nil {
logger.Error("error with store MidJourney task: ", err)
f := h.App.Functions[functionName]
data, err := f.Invoke(params)
if err != nil {
msg := "调用函数出错:" + err.Error()
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: msg,
})
contents = append(contents, msg)
} else {
content := data
if functionName == types.FuncMidJourney {
key := utils.Sha256(data)
logger.Debug(data, ",", key)
// add task for MidJourney
h.App.MjTaskClients.Put(key, ws)
task := types.MjTask{
UserId: userVo.Id,
RoleId: role.Id,
Icon: "/images/avatar/mid_journey.png",
ChatId: session.ChatId,
}
err := h.leveldb.Put(types.TaskStorePrefix+key, task)
if err != nil {
logger.Error("error with store MidJourney task: ", err)
}
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)

// update user's img_calls
h.db.Model(&user).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
}
content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
}

utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: content,
})
contents = append(contents, content)
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: content,
})
contents = append(contents, content)
}
}
}

// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
if userVo.ChatConfig.ApiKey == "" { // 如果用户使用的是自己绑定的 API KEY 则不扣减对话次数
res := h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
if res.Error != nil {
return res.Error
}
h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
}

if message.Role == "" {
Expand Down
1 change: 1 addition & 0 deletions api/handler/mj_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func (h *MidJourneyHandler) Notify(c *gin.Context) {
job.UserId = task.UserId
job.ChatId = task.ChatId
job.MessageId = data.MessageId
job.ReferenceId = data.ReferenceId
job.Content = data.Content
job.Prompt = data.Prompt
job.Image = utils.JsonEncode(data.Image)
Expand Down
3 changes: 2 additions & 1 deletion api/handler/user_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ func (h *UserHandler) Register(c *gin.Context) {
Model: h.App.ChatConfig.Model,
ApiKey: "",
}),
Calls: h.App.SysConfig.UserInitCalls,
Calls: h.App.SysConfig.UserInitCalls,
ImgCalls: h.App.SysConfig.InitImgCalls,
}
res = h.db.Create(&user)
if res.Error != nil {
Expand Down
19 changes: 10 additions & 9 deletions api/store/model/mj_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package model
import "time"

type MidJourneyJob struct {
Id uint `gorm:"primarykey;column:id"`
UserId uint
ChatId string
MessageId string
Hash string
Content string
Prompt string
Image string
CreatedAt time.Time
Id uint `gorm:"primarykey;column:id"`
UserId uint
ChatId string
MessageId string
ReferenceId string
Hash string
Content string
Prompt string
Image string
CreatedAt time.Time
}

func (MidJourneyJob) TableName() string {
Expand Down
1 change: 1 addition & 0 deletions api/store/model/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type User struct {
Salt string // 密码盐
Tokens int64 // 剩余tokens
Calls int // 剩余对话次数
ImgCalls int // 剩余绘图次数
ChatConfig string `gorm:"column:chat_config_json"` // 聊天配置 json
ChatRoles string `gorm:"column:chat_roles_json"` // 聊天角色
ExpiredTime int64 // 账户到期时间
Expand Down
7 changes: 4 additions & 3 deletions api/store/vo/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ type User struct {
Mobile string `json:"mobile"`
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
Salt string `json:"salt"` // 密码盐
Tokens int64 `json:"tokens"` // 剩余tokens
Calls int `json:"calls"` // 剩余对话次数
Salt string `json:"salt"` // 密码盐
Tokens int64 `json:"tokens"` // 剩余tokens
Calls int `json:"calls"` // 剩余对话次数
ImgCalls int `json:"img_calls"`
ChatConfig types.ChatConfig `json:"chat_config"` // 聊天配置
ChatRoles []string `json:"chat_roles"` // 聊天角色集合
ExpiredTime int64 `json:"expired_time"` // 账户到期时间
Expand Down
5 changes: 4 additions & 1 deletion database/update-v3.0.7.sql
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,7 @@ ALTER TABLE `chatgpt_mj_jobs`
--
ALTER TABLE `chatgpt_mj_jobs`
MODIFY `id` int NOT NULL AUTO_INCREMENT;
COMMIT;

ALTER TABLE `chatgpt_mj_jobs` ADD `reference_id` CHAR(40) NULL DEFAULT NULL COMMENT '引用消息 ID' AFTER `message_id`;

ALTER TABLE `chatgpt_users` ADD `img_calls` INT NOT NULL DEFAULT '0' COMMENT '剩余绘图次数' AFTER `calls`;
8 changes: 4 additions & 4 deletions web/src/views/admin/RewardList.vue
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<el-table-column prop="username" label="用户名"/>
<el-table-column prop="tx_id" label="转账单号"/>
<el-table-column prop="amount" label="转账金额"/>
<el-table-column prop="remark" label="备注"/>

<el-table-column label="转账时间">
<template #default="scope">
Expand All @@ -27,11 +28,10 @@
</template>

<script setup>
import {reactive, ref} from "vue";
import {httpGet, httpPost} from "@/utils/http";
import {ref} from "vue";
import {httpGet} from "@/utils/http";
import {ElMessage} from "element-plus";
import {dateFormat, disabledDate, removeArrayItem} from "@/utils/libs";
import {Plus} from "@element-plus/icons-vue";
import {dateFormat} from "@/utils/libs";
// 变量定义
const items = ref([])
Expand Down
5 changes: 4 additions & 1 deletion web/src/views/admin/SysConfig.vue
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
<el-form-item label="控制台标题" prop="admin_title">
<el-input v-model="system['admin_title']"/>
</el-form-item>
<el-form-item label="注册赠送次数" prop="init_calls">
<el-form-item label="赠送对话次数" prop="init_calls">
<el-input v-model.number="system['user_init_calls']" placeholder="新用户注册赠送对话次数"/>
</el-form-item>
<el-form-item label="赠送绘图次数" prop="init_calls">
<el-input v-model.number="system['init_img_calls']" placeholder="新用户注册赠送绘图次数"/>
</el-form-item>
<el-form-item label="短信验证服务" prop="enabled_msg_service">
<el-switch v-model="system['enabled_msg_service']"/>
</el-form-item>
Expand Down
3 changes: 3 additions & 0 deletions web/src/views/admin/UserList.vue
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@
<el-form-item label="提问次数:" prop="calls">
<el-input v-model.number="user.calls" autocomplete="off" placeholder="0"/>
</el-form-item>
<el-form-item label="绘图次数:" prop="img_calls">
<el-input v-model.number="user['img_calls']" autocomplete="off" placeholder="0"/>
</el-form-item>

<el-form-item label="有效期:" prop="expired_time">
<el-date-picker
Expand Down

0 comments on commit fab4309

Please sign in to comment.