forked from coaidev/coai
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add skylark models and baichuan models, fix copy clipboard feature in…
… mobile
- Loading branch information
1 parent
45d527f
commit 71b3b79
Showing
18 changed files
with
977 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
package baichuan | ||
|
||
import ( | ||
"chat/globals" | ||
"chat/utils" | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
type ChatProps struct { | ||
Model string | ||
Message []globals.Message | ||
Token int | ||
} | ||
|
||
func (c *ChatInstance) GetChatEndpoint() string { | ||
return fmt.Sprintf("%s/v1/chat/completions", c.GetEndpoint()) | ||
} | ||
|
||
func (c *ChatInstance) GetModel(model string) string { | ||
switch model { | ||
case globals.Baichuan53B: | ||
return "Baichuan2" | ||
default: | ||
return model | ||
} | ||
} | ||
|
||
func (c *ChatInstance) GetMessages(messages []globals.Message) []globals.Message { | ||
for _, message := range messages { | ||
if message.Role == globals.System { | ||
message.Role = globals.User | ||
} | ||
} | ||
|
||
return messages | ||
} | ||
|
||
func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { | ||
if props.Token != -1 { | ||
return ChatRequest{ | ||
Model: c.GetModel(props.Model), | ||
Messages: c.GetMessages(props.Message), | ||
MaxToken: props.Token, | ||
Stream: stream, | ||
} | ||
} | ||
|
||
return ChatRequestWithInfinity{ | ||
Model: c.GetModel(props.Model), | ||
Messages: c.GetMessages(props.Message), | ||
Stream: stream, | ||
} | ||
} | ||
|
||
// CreateChatRequest is the native http request body for baichuan | ||
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { | ||
res, err := utils.Post( | ||
c.GetChatEndpoint(), | ||
c.GetHeader(), | ||
c.GetChatBody(props, false), | ||
) | ||
|
||
if err != nil || res == nil { | ||
return "", fmt.Errorf("baichuan error: %s", err.Error()) | ||
} | ||
|
||
data := utils.MapToStruct[ChatResponse](res) | ||
if data == nil { | ||
return "", fmt.Errorf("baichuan error: cannot parse response") | ||
} else if data.Error.Message != "" { | ||
return "", fmt.Errorf("baichuan error: %s", data.Error.Message) | ||
} | ||
return data.Choices[0].Message.Content, nil | ||
} | ||
|
||
// CreateStreamChatRequest is the stream response body for baichuan | ||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { | ||
buf := "" | ||
cursor := 0 | ||
chunk := "" | ||
|
||
err := utils.EventSource( | ||
"POST", | ||
c.GetChatEndpoint(), | ||
c.GetHeader(), | ||
c.GetChatBody(props, true), | ||
func(data string) error { | ||
data, err := c.ProcessLine(buf, data) | ||
chunk += data | ||
|
||
if err != nil { | ||
if strings.HasPrefix(err.Error(), "baichuan error") { | ||
return err | ||
} | ||
|
||
// error when break line | ||
buf = buf + data | ||
return nil | ||
} | ||
|
||
buf = "" | ||
if data != "" { | ||
cursor += 1 | ||
if err := callback(data); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
}, | ||
) | ||
|
||
if err != nil { | ||
return err | ||
} else if len(chunk) == 0 { | ||
return fmt.Errorf("empty response") | ||
} | ||
|
||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
package baichuan | ||
|
||
import ( | ||
"chat/globals" | ||
"chat/utils" | ||
"errors" | ||
"fmt" | ||
"strings" | ||
) | ||
|
||
func processFormat(data string) string { | ||
rep := strings.NewReplacer( | ||
"data: {", | ||
"\"data\": {", | ||
) | ||
item := rep.Replace(data) | ||
if !strings.HasPrefix(item, "{") { | ||
item = "{" + item | ||
} | ||
if !strings.HasSuffix(item, "}}") { | ||
item = item + "}" | ||
} | ||
|
||
return item | ||
} | ||
|
||
func processChatResponse(data string) *ChatStreamResponse { | ||
if strings.HasPrefix(data, "{") { | ||
var form *ChatStreamResponse | ||
if form = utils.UnmarshalForm[ChatStreamResponse](data); form != nil { | ||
return form | ||
} | ||
|
||
if form = utils.UnmarshalForm[ChatStreamResponse](data[:len(data)-1]); form != nil { | ||
return form | ||
} | ||
|
||
if form = utils.UnmarshalForm[ChatStreamResponse](data + "}"); form != nil { | ||
return form | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func processChatErrorResponse(data string) *ChatStreamErrorResponse { | ||
if strings.HasPrefix(data, "{") { | ||
var form *ChatStreamErrorResponse | ||
if form = utils.UnmarshalForm[ChatStreamErrorResponse](data); form != nil { | ||
return form | ||
} | ||
if form = utils.UnmarshalForm[ChatStreamErrorResponse](data + "}"); form != nil { | ||
return form | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func isDone(data string) bool { | ||
return utils.Contains[string](data, []string{ | ||
"{data: [DONE]}", "{data: [DONE]}}", "null}}", "{null}", | ||
"{[DONE]}", "{data:}", "{data:}}", "data: [DONE]}}", | ||
}) | ||
} | ||
|
||
func getChoices(form *ChatStreamResponse) string { | ||
if len(form.Data.Choices) == 0 { | ||
if len(form.Choices) > 0 { | ||
return form.Choices[0].Delta.Content | ||
} | ||
} | ||
|
||
return form.Data.Choices[0].Delta.Content | ||
} | ||
|
||
func (c *ChatInstance) ProcessLine(buf, data string) (string, error) { | ||
item := processFormat(buf + data) | ||
if isDone(item) { | ||
return "", nil | ||
} | ||
|
||
fmt.Println(item) | ||
if form := processChatResponse(item); form == nil { | ||
// recursive call | ||
if len(buf) > 0 { | ||
return c.ProcessLine("", buf+item) | ||
} | ||
|
||
if err := processChatErrorResponse(item); err == nil || err.Data.Error.Message == "" { | ||
globals.Warn(fmt.Sprintf("baichuan error: cannot parse response: %s", item)) | ||
return data, errors.New("parser error: cannot parse response") | ||
} else { | ||
return "", fmt.Errorf("baichuan error: %s (type: %s)", err.Data.Error.Message, err.Data.Error.Type) | ||
} | ||
|
||
} else { | ||
return getChoices(form), nil | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package baichuan | ||
|
||
import ( | ||
"fmt" | ||
"github.com/spf13/viper" | ||
) | ||
|
||
type ChatInstance struct { | ||
Endpoint string | ||
ApiKey string | ||
} | ||
|
||
func (c *ChatInstance) GetEndpoint() string { | ||
return c.Endpoint | ||
} | ||
|
||
func (c *ChatInstance) GetApiKey() string { | ||
return c.ApiKey | ||
} | ||
|
||
func (c *ChatInstance) GetHeader() map[string]string { | ||
return map[string]string{ | ||
"Content-Type": "application/json", | ||
"Authorization": fmt.Sprintf("Bearer %s", c.GetApiKey()), | ||
} | ||
} | ||
|
||
func NewChatInstance(endpoint, apiKey string) *ChatInstance { | ||
return &ChatInstance{ | ||
Endpoint: endpoint, | ||
ApiKey: apiKey, | ||
} | ||
} | ||
|
||
func NewChatInstanceFromConfig() *ChatInstance { | ||
return NewChatInstance( | ||
viper.GetString("baichuan.endpoint"), | ||
viper.GetString("baichuan.apikey"), | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package baichuan | ||
|
||
import "chat/globals" | ||
|
||
// Baichuan AI API is similar to OpenAI API | ||
|
||
// ChatRequest is the request body for baichuan | ||
type ChatRequest struct { | ||
Model string `json:"model"` | ||
Messages []globals.Message `json:"messages"` | ||
MaxToken int `json:"max_tokens"` | ||
Stream bool `json:"stream"` | ||
} | ||
|
||
type ChatRequestWithInfinity struct { | ||
Model string `json:"model"` | ||
Messages []globals.Message `json:"messages"` | ||
Stream bool `json:"stream"` | ||
} | ||
|
||
// ChatResponse is the native http request body for baichuan | ||
type ChatResponse struct { | ||
ID string `json:"id"` | ||
Object string `json:"object"` | ||
Created int64 `json:"created"` | ||
Model string `json:"model"` | ||
Choices []struct { | ||
Message struct { | ||
Content string `json:"content"` | ||
} | ||
} `json:"choices"` | ||
Error struct { | ||
Message string `json:"message"` | ||
} `json:"error"` | ||
} | ||
|
||
// ChatStreamResponse is the stream response body for baichuan | ||
type ChatStreamResponse struct { | ||
Data struct { | ||
ID string `json:"id"` | ||
Object string `json:"object"` | ||
Created int64 `json:"created"` | ||
Model string `json:"model"` | ||
Choices []struct { | ||
Delta struct { | ||
Content string `json:"content"` | ||
} | ||
Index int `json:"index"` | ||
} `json:"choices"` | ||
} `json:"data"` | ||
|
||
ID string `json:"id"` | ||
Object string `json:"object"` | ||
Created int64 `json:"created"` | ||
Model string `json:"model"` | ||
Choices []struct { | ||
Delta struct { | ||
Content string `json:"content"` | ||
} | ||
Index int `json:"index"` | ||
} `json:"choices"` | ||
} | ||
|
||
type ChatStreamErrorResponse struct { | ||
Data struct { | ||
Error struct { | ||
Message string `json:"message"` | ||
Type string `json:"type"` | ||
} `json:"error"` | ||
} `json:"data"` | ||
} |
Oops, something went wrong.