Skip to content

Commit

Permalink
feat: add support for /v1/engines/text-embedding-ada-002/embeddings (s…
Browse files Browse the repository at this point in the history
  • Loading branch information
playniuniu authored Jul 15, 2023
1 parent abc53cb commit 81c5901
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
6 changes: 5 additions & 1 deletion controller/relay-text.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"

"github.com/gin-gonic/gin"
)

func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
Expand All @@ -30,6 +31,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
Expand Down
5 changes: 4 additions & 1 deletion controller/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package controller

import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"strings"

"github.com/gin-gonic/gin"
)

type Message struct {
Expand Down Expand Up @@ -100,6 +101,8 @@ func Relay(c *gin.Context) {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
Expand Down
8 changes: 7 additions & 1 deletion middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package middleware

import (
"fmt"
"github.com/gin-gonic/gin"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"

"github.com/gin-gonic/gin"
)

type ModelRequest struct {
Expand Down Expand Up @@ -73,6 +74,11 @@ func Distribute() func(c *gin.Context) {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil {
message := "无可用渠道"
Expand Down
4 changes: 3 additions & 1 deletion router/relay-router.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package router

import (
"github.com/gin-gonic/gin"
"one-api/controller"
"one-api/middleware"

"github.com/gin-gonic/gin"
)

func SetRelayRouter(router *gin.Engine) {
Expand All @@ -24,6 +25,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.POST("/embeddings", controller.Relay)
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
relayV1Router.GET("/files", controller.RelayNotImplemented)
Expand Down

0 comments on commit 81c5901

Please sign in to comment.