Skip to content

Commit

Permalink
feat: add cloudflare ai gateway support for image & audio (songquanpe…
Browse files Browse the repository at this point in the history
…ng#607)

* Update channel-test.go

* Update relay-audio.go

* Update relay-image.go

* chore: using a util function

---------

Co-authored-by: JustSong <[email protected]>
  • Loading branch information
v1cc0 and songquanpeng authored Oct 22, 2023
1 parent 22980b4 commit 3b48363
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 18 deletions.
3 changes: 2 additions & 1 deletion controller/channel-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"sync"
"time"

"github.com/gin-gonic/gin"
)

func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
Expand Down
6 changes: 2 additions & 4 deletions controller/relay-audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"

"github.com/gin-gonic/gin"
)

func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
Expand Down Expand Up @@ -66,12 +65,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode

baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()

if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}

fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
requestBody := c.Request.Body

req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
Expand Down
9 changes: 2 additions & 7 deletions controller/relay-image.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"

"github.com/gin-gonic/gin"
)

func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
Expand Down Expand Up @@ -61,16 +60,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
isModelMapped = true
}
}

baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()

if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}

fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)

fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(imageRequest)
Expand Down
7 changes: 1 addition & 6 deletions controller/relay-text.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
}
}
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
Expand Down
10 changes: 10 additions & 0 deletions controller/relay-utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,13 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
return
}

func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if channelType == common.ChannelTypeOpenAI {
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
}
}
return fullRequestURL
}

0 comments on commit 3b48363

Please sign in to comment.