Skip to content

Commit

Permalink
feat: support baichuan's models now (close songquanpeng#1057)
Browse files Browse the repository at this point in the history
  • Loading branch information
songquanpeng committed Mar 1, 2024
1 parent eac6a0b commit 614c2e0
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [360 智脑](https://ai.360.cn)
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
+ [x] [Moonshot AI](https://platform.moonshot.cn/)
+ [x] [百川大模型](https://platform.baichuan-ai.com)
+ [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP)
+ [ ] [MINIMAX](https://api.minimax.chat/) (WIP)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)
Expand Down
2 changes: 2 additions & 0 deletions common/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ const (
ChannelTypeTencent = 23
ChannelTypeGemini = 24
ChannelTypeMoonshot = 25
ChannelTypeBaichuan = 26
)

var ChannelBaseURLs = []string{
Expand Down Expand Up @@ -93,6 +94,7 @@ var ChannelBaseURLs = []string{
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26
}

const (
Expand Down
9 changes: 9 additions & 0 deletions common/logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
)

const (
loggerDEBUG = "DEBUG"
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
Expand Down Expand Up @@ -55,6 +56,10 @@ func SysError(s string) {
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}

func Debug(ctx context.Context, msg string) {
logHelper(ctx, loggerDEBUG, msg)
}

func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
Expand All @@ -67,6 +72,10 @@ func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}

func Debugf(ctx context.Context, format string, a ...any) {
Debug(ctx, fmt.Sprintf(format, a...))
}

func Infof(ctx context.Context, format string, a ...any) {
Info(ctx, fmt.Sprintf(format, a...))
}
Expand Down
12 changes: 12 additions & 0 deletions controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
Expand Down Expand Up @@ -98,6 +99,17 @@ func init() {
Parent: nil,
})
}
for _, modelName := range baichuan.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "baichuan",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
Expand Down
4 changes: 4 additions & 0 deletions controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
func Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
bizErr := relay(c, relayMode)
if bizErr == nil {
return
Expand Down
7 changes: 7 additions & 0 deletions relay/channel/baichuan/constants.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package baichuan

var ModelList = []string{
"Baichuan2-Turbo",
"Baichuan2-Turbo-192k",
"Baichuan-Text-Embedding",
}
5 changes: 5 additions & 0 deletions relay/channel/openai/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
Expand Down Expand Up @@ -84,6 +85,8 @@ func (a *Adaptor) GetModelList() []string {
return ai360.ModelList
case common.ChannelTypeMoonshot:
return moonshot.ModelList
case common.ChannelTypeBaichuan:
return baichuan.ModelList
default:
return ModelList
}
Expand All @@ -97,6 +100,8 @@ func (a *Adaptor) GetChannelName() string {
return "360"
case common.ChannelTypeMoonshot:
return "moonshot"
case common.ChannelTypeBaichuan:
return "baichuan"
default:
return "openai"
}
Expand Down
3 changes: 2 additions & 1 deletion relay/controller/text.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
var requestBody io.Reader
if meta.APIType == constant.APITypeOpenAI {
// no need to convert request for openai
if isModelMapped {
shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
Expand Down
6 changes: 6 additions & 0 deletions web/berry/src/constants/ChannelConstants.js
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ export const CHANNEL_OPTIONS = {
value: 23,
color: 'default'
},
26: {
key: 26,
text: '百川大模型',
value: 23,
color: 'default'
},
8: {
key: 8,
text: '自定义渠道',
Expand Down
12 changes: 12 additions & 0 deletions web/berry/src/views/Channel/type/Config.js
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ const typeConfig = {
},
modelGroup: "google gemini",
},
25: {
input: {
models: ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'],
},
modelGroup: "moonshot",
},
26: {
input: {
models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'],
},
modelGroup: "baichuan",
},
};

export { defaultConfig, typeConfig };
1 change: 1 addition & 0 deletions web/default/src/constants/channel.constants.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export const CHANNEL_OPTIONS = [
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
{ key: 25, text: 'Moonshot AI', value: 25, color: 'black' },
{ key: 23, text: '腾讯混元', value: 23, color: 'teal' },
{ key: 26, text: '百川大模型', value: 26, color: 'orange' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
Expand Down
3 changes: 3 additions & 0 deletions web/default/src/pages/Channel/EditChannel.js
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ const EditChannel = () => {
case 25:
localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'];
break;
case 26:
localModels = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'];
break;
}
setInputs((inputs) => ({ ...inputs, models: localModels }));
}
Expand Down

0 comments on commit 614c2e0

Please sign in to comment.