
Commit c243cd5

feat: support ollama's embedding API (songquanpeng#1221)

* add ollama embedding API
* chore: fix function name
---------
Co-authored-by: JustSong <[email protected]>
1 parent e96b173 commit c243cd5
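With this change, an OpenAI-style embedding request sent to a one-api deployment that routes to an ollama channel should be forwarded to Ollama's /api/embeddings endpoint. Below is a minimal client sketch, assuming the usual OpenAI-compatible /v1/embeddings route; the base URL, token, and model name are placeholders, not part of this commit.

// Hypothetical client call; URL, token, and model name are placeholders.
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body := []byte(`{"model":"nomic-embed-text","input":"hello world"}`)
	req, err := http.NewRequest(http.MethodPost, "http://localhost:3000/v1/embeddings", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer sk-xxxx") // placeholder one-api token
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	b, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.Status, string(b)) // expect an OpenAI-shaped embedding list
}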

File tree

3 files changed (+86, -7 lines)

relay/channel/ollama/adaptor.go (+14 -4)
@@ -3,13 +3,14 @@ package ollama
 import (
 	"errors"
 	"fmt"
+	"io"
+	"net/http"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/relay/channel"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
-	"io"
-	"net/http"
 )

 type Adaptor struct {
@@ -22,6 +23,9 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
 func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 	// https://github.com/ollama/ollama/blob/main/docs/api.md
 	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
+	if meta.Mode == constant.RelayModeEmbeddings {
+		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
+	}
 	return fullRequestURL, nil
 }

@@ -37,7 +41,8 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 	switch relayMode {
 	case constant.RelayModeEmbeddings:
-		return nil, errors.New("not supported")
+		ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
+		return ollamaEmbeddingRequest, nil
 	default:
 		return ConvertRequest(*request), nil
 	}
@@ -51,7 +56,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
 	if meta.IsStream {
 		err, usage = StreamHandler(c, resp)
 	} else {
-		err, usage = Handler(c, resp)
+		switch meta.Mode {
+		case constant.RelayModeEmbeddings:
+			err, usage = EmbeddingHandler(c, resp)
+		default:
+			err, usage = Handler(c, resp)
+		}
 	}
 	return
 }
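The adaptor change is mostly routing: pick Ollama's embeddings endpoint when the relay mode is embeddings, otherwise keep the chat endpoint. Here is a self-contained sketch of that decision; relayMeta and the mode constants are simplified stand-ins for the real util.RelayMeta and relay/constant values, not the project's types.

// Simplified stand-ins for util.RelayMeta and relay/constant.
package main

import "fmt"

const (
	relayModeChatCompletions = iota // stand-in for constant.RelayModeChatCompletions
	relayModeEmbeddings             // stand-in for constant.RelayModeEmbeddings
)

type relayMeta struct {
	BaseURL string
	Mode    int
}

func getRequestURL(meta *relayMeta) string {
	// Default to the chat endpoint; embeddings get their own path.
	fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
	if meta.Mode == relayModeEmbeddings {
		fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
	}
	return fullRequestURL
}

func main() {
	fmt.Println(getRequestURL(&relayMeta{BaseURL: "http://localhost:11434", Mode: relayModeEmbeddings}))
	// http://localhost:11434/api/embeddings
	fmt.Println(getRequestURL(&relayMeta{BaseURL: "http://localhost:11434", Mode: relayModeChatCompletions}))
	// http://localhost:11434/api/chat
}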

relay/channel/ollama/main.go (+62 -3)
@@ -5,16 +5,17 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
+	"net/http"
+	"strings"
+
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
-	"io"
-	"net/http"
-	"strings"
 )

 func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
@@ -139,6 +140,64 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 	return nil, &usage
 }

+func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+	return &EmbeddingRequest{
+		Model:  request.Model,
+		Prompt: strings.Join(request.ParseInput(), " "),
+	}
+}
+
+func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+	var ollamaResponse EmbeddingResponse
+	err := json.NewDecoder(resp.Body).Decode(&ollamaResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	err = resp.Body.Close()
+	if err != nil {
+		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+	}
+
+	if ollamaResponse.Error != "" {
+		return &model.ErrorWithStatusCode{
+			Error: model.Error{
+				Message: ollamaResponse.Error,
+				Type:    "ollama_error",
+				Param:   "",
+				Code:    "ollama_error",
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+
+	fullTextResponse := embeddingResponseOllama2OpenAI(&ollamaResponse)
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &fullTextResponse.Usage
+}
+
+func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
+	openAIEmbeddingResponse := openai.EmbeddingResponse{
+		Object: "list",
+		Data:   make([]openai.EmbeddingResponseItem, 0, 1),
+		Model:  "text-embedding-v1",
+		Usage:  model.Usage{TotalTokens: 0},
+	}
+
+	openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
+		Object:    `embedding`,
+		Index:     0,
+		Embedding: response.Embedding,
+	})
+	return &openAIEmbeddingResponse
+}
+
 func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 	ctx := context.TODO()
 	var ollamaResponse ChatResponse
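Worth noting: Ollama's /api/embeddings takes a single prompt string and returns a single vector, so ConvertEmbeddingRequest joins all OpenAI inputs with spaces, and EmbeddingHandler emits exactly one embedding at index 0 (with TotalTokens left at 0, since Ollama reports no usage here). A self-contained sketch of that flattening follows, using a local copy of the request type; the model name "nomic-embed-text" is illustrative, not from the patch.

// Local copy of the wire type; the model name is illustrative.
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

type embeddingRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
}

// convertEmbeddingRequest mirrors the patch: all OpenAI inputs are
// flattened into Ollama's single prompt string with strings.Join.
func convertEmbeddingRequest(model string, input []string) *embeddingRequest {
	return &embeddingRequest{
		Model:  model,
		Prompt: strings.Join(input, " "),
	}
}

func main() {
	req := convertEmbeddingRequest("nomic-embed-text", []string{"hello", "world"})
	b, _ := json.Marshal(req)
	fmt.Println(string(b)) // {"model":"nomic-embed-text","prompt":"hello world"}
}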

relay/channel/ollama/model.go (+10 -0)
@@ -35,3 +35,13 @@ type ChatResponse struct {
 	EvalDuration int    `json:"eval_duration,omitempty"`
 	Error        string `json:"error,omitempty"`
 }
+
+type EmbeddingRequest struct {
+	Model  string `json:"model"`
+	Prompt string `json:"prompt"`
+}
+
+type EmbeddingResponse struct {
+	Error     string    `json:"error,omitempty"`
+	Embedding []float64 `json:"embedding,omitempty"`
+}
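The new EmbeddingResponse carries either a vector or an error on Ollama's wire format, and EmbeddingHandler branches on the Error field. A minimal decode sketch with illustrative JSON bodies (not captured from a live server):

// Illustrative JSON bodies; the error message is made up.
package main

import (
	"encoding/json"
	"fmt"
)

type embeddingResponse struct { // local copy of ollama.EmbeddingResponse
	Error     string    `json:"error,omitempty"`
	Embedding []float64 `json:"embedding,omitempty"`
}

func main() {
	bodies := []string{
		`{"embedding":[0.5,-0.2,0.8]}`,       // success: vector only
		`{"error":"model 'nope' not found"}`, // failure: error only
	}
	for _, body := range bodies {
		var r embeddingResponse
		if err := json.Unmarshal([]byte(body), &r); err != nil {
			panic(err)
		}
		if r.Error != "" {
			// EmbeddingHandler surfaces this as type/code "ollama_error".
			fmt.Println("error:", r.Error)
			continue
		}
		fmt.Println("dims:", len(r.Embedding))
	}
}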
