@@ -98,13 +98,17 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
98
98
var ratio float64
99
99
var modelRatio float64
100
100
//err := service.SensitiveWordsCheck(textRequest)
101
- promptTokens , err , sensitiveTrigger := getPromptTokens (textRequest , relayInfo )
102
101
103
- // count messages token error 计算promptTokens错误
104
- if err != nil {
105
- if sensitiveTrigger {
102
+ if constant . ShouldCheckPromptSensitive () {
103
+ err = checkRequestSensitive ( textRequest , relayInfo )
104
+ if err != nil {
106
105
return service .OpenAIErrorWrapperLocal (err , "sensitive_words_detected" , http .StatusBadRequest )
107
106
}
107
+ }
108
+
109
+ promptTokens , err := getPromptTokens (textRequest , relayInfo )
110
+ // count messages token error 计算promptTokens错误
111
+ if err != nil {
108
112
return service .OpenAIErrorWrapper (err , "count_token_messages_failed" , http .StatusInternalServerError )
109
113
}
110
114
@@ -128,15 +132,15 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
128
132
129
133
adaptor := GetAdaptor (relayInfo .ApiType )
130
134
if adaptor == nil {
131
- return service .OpenAIErrorWrapper (fmt .Errorf ("invalid api type: %d" , relayInfo .ApiType ), "invalid_api_type" , http .StatusBadRequest )
135
+ return service .OpenAIErrorWrapperLocal (fmt .Errorf ("invalid api type: %d" , relayInfo .ApiType ), "invalid_api_type" , http .StatusBadRequest )
132
136
}
133
137
adaptor .Init (relayInfo , * textRequest )
134
138
var requestBody io.Reader
135
139
if relayInfo .ApiType == relayconstant .APITypeOpenAI {
136
140
if isModelMapped {
137
141
jsonStr , err := json .Marshal (textRequest )
138
142
if err != nil {
139
- return service .OpenAIErrorWrapper (err , "marshal_text_request_failed" , http .StatusInternalServerError )
143
+ return service .OpenAIErrorWrapperLocal (err , "marshal_text_request_failed" , http .StatusInternalServerError )
140
144
}
141
145
requestBody = bytes .NewBuffer (jsonStr )
142
146
} else {
@@ -145,11 +149,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
145
149
} else {
146
150
convertedRequest , err := adaptor .ConvertRequest (c , relayInfo .RelayMode , textRequest )
147
151
if err != nil {
148
- return service .OpenAIErrorWrapper (err , "convert_request_failed" , http .StatusInternalServerError )
152
+ return service .OpenAIErrorWrapperLocal (err , "convert_request_failed" , http .StatusInternalServerError )
149
153
}
150
154
jsonData , err := json .Marshal (convertedRequest )
151
155
if err != nil {
152
- return service .OpenAIErrorWrapper (err , "json_marshal_failed" , http .StatusInternalServerError )
156
+ return service .OpenAIErrorWrapperLocal (err , "json_marshal_failed" , http .StatusInternalServerError )
153
157
}
154
158
requestBody = bytes .NewBuffer (jsonData )
155
159
}
@@ -182,26 +186,39 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
182
186
return nil
183
187
}
184
188
185
- func getPromptTokens (textRequest * dto.GeneralOpenAIRequest , info * relaycommon.RelayInfo ) (int , error , bool ) {
189
+ func getPromptTokens (textRequest * dto.GeneralOpenAIRequest , info * relaycommon.RelayInfo ) (int , error ) {
186
190
var promptTokens int
187
191
var err error
188
- var sensitiveTrigger bool
189
- checkSensitive := constant .ShouldCheckPromptSensitive ()
190
192
switch info .RelayMode {
191
193
case relayconstant .RelayModeChatCompletions :
192
- promptTokens , err , sensitiveTrigger = service .CountTokenChatRequest (* textRequest , textRequest .Model , checkSensitive )
194
+ promptTokens , err = service .CountTokenChatRequest (* textRequest , textRequest .Model )
193
195
case relayconstant .RelayModeCompletions :
194
- promptTokens , err , sensitiveTrigger = service .CountTokenInput (textRequest .Prompt , textRequest .Model , checkSensitive )
196
+ promptTokens , err = service .CountTokenInput (textRequest .Prompt , textRequest .Model )
195
197
case relayconstant .RelayModeModerations :
196
- promptTokens , err , sensitiveTrigger = service .CountTokenInput (textRequest .Input , textRequest .Model , checkSensitive )
198
+ promptTokens , err = service .CountTokenInput (textRequest .Input , textRequest .Model )
197
199
case relayconstant .RelayModeEmbeddings :
198
- promptTokens , err , sensitiveTrigger = service .CountTokenInput (textRequest .Input , textRequest .Model , checkSensitive )
200
+ promptTokens , err = service .CountTokenInput (textRequest .Input , textRequest .Model )
199
201
default :
200
202
err = errors .New ("unknown relay mode" )
201
203
promptTokens = 0
202
204
}
203
205
info .PromptTokens = promptTokens
204
- return promptTokens , err , sensitiveTrigger
206
+ return promptTokens , err
207
+ }
208
+
209
+ func checkRequestSensitive (textRequest * dto.GeneralOpenAIRequest , info * relaycommon.RelayInfo ) error {
210
+ var err error
211
+ switch info .RelayMode {
212
+ case relayconstant .RelayModeChatCompletions :
213
+ err = service .CheckSensitiveMessages (textRequest .Messages )
214
+ case relayconstant .RelayModeCompletions :
215
+ err = service .CheckSensitiveInput (textRequest .Prompt )
216
+ case relayconstant .RelayModeModerations :
217
+ err = service .CheckSensitiveInput (textRequest .Input )
218
+ case relayconstant .RelayModeEmbeddings :
219
+ err = service .CheckSensitiveInput (textRequest .Input )
220
+ }
221
+ return err
205
222
}
206
223
207
224
// 预扣费并返回用户剩余配额
0 commit comments