forked from songquanpeng/one-api
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support baidu's models now (close songquanpeng#286)
- Loading branch information
1 parent
3c94011
commit 9a1db61
Showing
7 changed files
with
268 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
package controller | ||
|
||
import ( | ||
"bufio" | ||
"encoding/json" | ||
"github.com/gin-gonic/gin" | ||
"io" | ||
"net/http" | ||
"one-api/common" | ||
"strings" | ||
) | ||
|
||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2 | ||
|
||
type BaiduTokenResponse struct { | ||
RefreshToken string `json:"refresh_token"` | ||
ExpiresIn int `json:"expires_in"` | ||
SessionKey string `json:"session_key"` | ||
AccessToken string `json:"access_token"` | ||
Scope string `json:"scope"` | ||
SessionSecret string `json:"session_secret"` | ||
} | ||
|
||
type BaiduMessage struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
} | ||
|
||
type BaiduChatRequest struct { | ||
Messages []BaiduMessage `json:"messages"` | ||
Stream bool `json:"stream"` | ||
UserId string `json:"user_id,omitempty"` | ||
} | ||
|
||
type BaiduError struct { | ||
ErrorCode int `json:"error_code"` | ||
ErrorMsg string `json:"error_msg"` | ||
} | ||
|
||
type BaiduChatResponse struct { | ||
Id string `json:"id"` | ||
Object string `json:"object"` | ||
Created int64 `json:"created"` | ||
Result string `json:"result"` | ||
IsTruncated bool `json:"is_truncated"` | ||
NeedClearHistory bool `json:"need_clear_history"` | ||
Usage Usage `json:"usage"` | ||
BaiduError | ||
} | ||
|
||
type BaiduChatStreamResponse struct { | ||
BaiduChatResponse | ||
SentenceId int `json:"sentence_id"` | ||
IsEnd bool `json:"is_end"` | ||
} | ||
|
||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { | ||
messages := make([]BaiduMessage, 0, len(request.Messages)) | ||
for _, message := range request.Messages { | ||
messages = append(messages, BaiduMessage{ | ||
Role: message.Role, | ||
Content: message.Content, | ||
}) | ||
} | ||
return &BaiduChatRequest{ | ||
Messages: messages, | ||
Stream: request.Stream, | ||
} | ||
} | ||
|
||
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { | ||
choice := OpenAITextResponseChoice{ | ||
Index: 0, | ||
Message: Message{ | ||
Role: "assistant", | ||
Content: response.Result, | ||
}, | ||
FinishReason: "stop", | ||
} | ||
fullTextResponse := OpenAITextResponse{ | ||
Id: response.Id, | ||
Object: "chat.completion", | ||
Created: response.Created, | ||
Choices: []OpenAITextResponseChoice{choice}, | ||
Usage: response.Usage, | ||
} | ||
return &fullTextResponse | ||
} | ||
|
||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse { | ||
var choice ChatCompletionsStreamResponseChoice | ||
choice.Delta.Content = baiduResponse.Result | ||
choice.FinishReason = "stop" | ||
response := ChatCompletionsStreamResponse{ | ||
Id: baiduResponse.Id, | ||
Object: "chat.completion.chunk", | ||
Created: baiduResponse.Created, | ||
Model: "ernie-bot", | ||
Choices: []ChatCompletionsStreamResponseChoice{choice}, | ||
} | ||
return &response | ||
} | ||
|
||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||
var usage Usage | ||
scanner := bufio.NewScanner(resp.Body) | ||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||
if atEOF && len(data) == 0 { | ||
return 0, nil, nil | ||
} | ||
if i := strings.Index(string(data), "\n"); i >= 0 { | ||
return i + 1, data[0:i], nil | ||
} | ||
if atEOF { | ||
return len(data), data, nil | ||
} | ||
return 0, nil, nil | ||
}) | ||
dataChan := make(chan string) | ||
stopChan := make(chan bool) | ||
go func() { | ||
for scanner.Scan() { | ||
data := scanner.Text() | ||
if len(data) < 6 { // ignore blank line or wrong format | ||
continue | ||
} | ||
data = data[6:] | ||
dataChan <- data | ||
} | ||
stopChan <- true | ||
}() | ||
c.Writer.Header().Set("Content-Type", "text/event-stream") | ||
c.Writer.Header().Set("Cache-Control", "no-cache") | ||
c.Writer.Header().Set("Connection", "keep-alive") | ||
c.Writer.Header().Set("Transfer-Encoding", "chunked") | ||
c.Writer.Header().Set("X-Accel-Buffering", "no") | ||
c.Stream(func(w io.Writer) bool { | ||
select { | ||
case data := <-dataChan: | ||
var baiduResponse BaiduChatStreamResponse | ||
err := json.Unmarshal([]byte(data), &baiduResponse) | ||
if err != nil { | ||
common.SysError("error unmarshalling stream response: " + err.Error()) | ||
return true | ||
} | ||
usage.PromptTokens += baiduResponse.Usage.PromptTokens | ||
usage.CompletionTokens += baiduResponse.Usage.CompletionTokens | ||
usage.TotalTokens += baiduResponse.Usage.TotalTokens | ||
response := streamResponseBaidu2OpenAI(&baiduResponse) | ||
jsonResponse, err := json.Marshal(response) | ||
if err != nil { | ||
common.SysError("error marshalling stream response: " + err.Error()) | ||
return true | ||
} | ||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) | ||
return true | ||
case <-stopChan: | ||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||
return false | ||
} | ||
}) | ||
err := resp.Body.Close() | ||
if err != nil { | ||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||
} | ||
return nil, &usage | ||
} | ||
|
||
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { | ||
var baiduResponse BaiduChatResponse | ||
responseBody, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||
} | ||
err = resp.Body.Close() | ||
if err != nil { | ||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||
} | ||
err = json.Unmarshal(responseBody, &baiduResponse) | ||
if err != nil { | ||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||
} | ||
if baiduResponse.ErrorMsg != "" { | ||
return &OpenAIErrorWithStatusCode{ | ||
OpenAIError: OpenAIError{ | ||
Message: baiduResponse.ErrorMsg, | ||
Type: "baidu_error", | ||
Param: "", | ||
Code: baiduResponse.ErrorCode, | ||
}, | ||
StatusCode: resp.StatusCode, | ||
}, nil | ||
} | ||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse) | ||
jsonResponse, err := json.Marshal(fullTextResponse) | ||
if err != nil { | ||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||
} | ||
c.Writer.Header().Set("Content-Type", "application/json") | ||
c.Writer.WriteHeader(resp.StatusCode) | ||
_, err = c.Writer.Write(jsonResponse) | ||
return nil, &fullTextResponse.Usage | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters