Skip to content

Commit bfbbe67

Browse files
committed
refactor: 重构敏感词
1 parent 0867d36 commit bfbbe67

12 files changed

+158
-91
lines changed

common/str.go

+30
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
package common
22

3+
import (
4+
"bytes"
5+
"fmt"
6+
goahocorasick "github.com/anknown/ahocorasick"
7+
"one-api/constant"
8+
"strings"
9+
)
10+
311
func SundaySearch(text string, pattern string) bool {
412
// 计算偏移表
513
offset := make(map[rune]int)
@@ -48,3 +56,25 @@ func RemoveDuplicate(s []string) []string {
4856
}
4957
return result
5058
}
59+
60+
func InitAc() *goahocorasick.Machine {
61+
m := new(goahocorasick.Machine)
62+
dict := readRunes()
63+
if err := m.Build(dict); err != nil {
64+
fmt.Println(err)
65+
return nil
66+
}
67+
return m
68+
}
69+
70+
func readRunes() [][]rune {
71+
var dict [][]rune
72+
73+
for _, word := range constant.SensitiveWords {
74+
word = strings.ToLower(word)
75+
l := bytes.TrimSpace([]byte(word))
76+
dict = append(dict, bytes.Runes(l))
77+
}
78+
79+
return dict
80+
}

constant/sensitive.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ var StreamCacheQueueLength = 0
1616
// SensitiveWords 敏感词
1717
// var SensitiveWords []string
1818
var SensitiveWords = []string{
19-
"test",
19+
"test_sensitive",
2020
}
2121

2222
func SensitiveWordsToString() string {

relay/channel/claude/relay-claude.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
370370
}, nil
371371
}
372372
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
373-
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
373+
completionTokens, err := service.CountTokenText(claudeResponse.Completion, model)
374374
if err != nil {
375375
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
376376
}

relay/channel/gemini/relay-gemini.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
256256
}, nil
257257
}
258258
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
259-
completionTokens, _, _ := service.CountTokenText(geminiResponse.GetResponseText(), model, false)
259+
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
260260
usage := dto.Usage{
261261
PromptTokens: promptTokens,
262262
CompletionTokens: completionTokens,

relay/channel/openai/relay-openai.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
190190
if simpleResponse.Usage.TotalTokens == 0 {
191191
completionTokens := 0
192192
for _, choice := range simpleResponse.Choices {
193-
ctkm, _, _ := service.CountTokenText(string(choice.Message.Content), model, false)
193+
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model)
194194
completionTokens += ctkm
195195
}
196196
simpleResponse.Usage = dto.Usage{

relay/channel/palm/relay-palm.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
156156
}, nil
157157
}
158158
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
159-
completionTokens, _, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model, false)
159+
completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model)
160160
usage := dto.Usage{
161161
PromptTokens: promptTokens,
162162
CompletionTokens: completionTokens,

relay/relay-audio.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,13 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
5555
promptTokens := 0
5656
preConsumedTokens := common.PreConsumedQuota
5757
if strings.HasPrefix(audioRequest.Model, "tts-1") {
58-
promptTokens, err, _ = service.CountAudioToken(audioRequest.Input, audioRequest.Model, constant.ShouldCheckPromptSensitive())
58+
if constant.ShouldCheckPromptSensitive() {
59+
err = service.CheckSensitiveInput(audioRequest.Input)
60+
if err != nil {
61+
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
62+
}
63+
}
64+
promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model)
5965
if err != nil {
6066
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
6167
}
@@ -178,7 +184,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
178184
if strings.HasPrefix(audioRequest.Model, "tts-1") {
179185
quota = promptTokens
180186
} else {
181-
quota, err, _ = service.CountAudioToken(audioResponse.Text, audioRequest.Model, false)
187+
quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model)
182188
}
183189
quota = int(float64(quota) * ratio)
184190
if ratio != 0 && quota <= 0 {

relay/relay-image.go

+8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"io"
1111
"net/http"
1212
"one-api/common"
13+
"one-api/constant"
1314
"one-api/dto"
1415
"one-api/model"
1516
relaycommon "one-api/relay/common"
@@ -47,6 +48,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
4748
return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
4849
}
4950

51+
if constant.ShouldCheckPromptSensitive() {
52+
err = service.CheckSensitiveInput(imageRequest.Prompt)
53+
if err != nil {
54+
return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest)
55+
}
56+
}
57+
5058
if strings.Contains(imageRequest.Size, "×") {
5159
return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest)
5260
}

relay/relay-text.go

+33-16
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,17 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
9898
var ratio float64
9999
var modelRatio float64
100100
//err := service.SensitiveWordsCheck(textRequest)
101-
promptTokens, err, sensitiveTrigger := getPromptTokens(textRequest, relayInfo)
102101

103-
// count messages token error 计算promptTokens错误
104-
if err != nil {
105-
if sensitiveTrigger {
102+
if constant.ShouldCheckPromptSensitive() {
103+
err = checkRequestSensitive(textRequest, relayInfo)
104+
if err != nil {
106105
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)
107106
}
107+
}
108+
109+
promptTokens, err := getPromptTokens(textRequest, relayInfo)
110+
// count messages token error 计算promptTokens错误
111+
if err != nil {
108112
return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
109113
}
110114

@@ -128,15 +132,15 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
128132

129133
adaptor := GetAdaptor(relayInfo.ApiType)
130134
if adaptor == nil {
131-
return service.OpenAIErrorWrapper(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
135+
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
132136
}
133137
adaptor.Init(relayInfo, *textRequest)
134138
var requestBody io.Reader
135139
if relayInfo.ApiType == relayconstant.APITypeOpenAI {
136140
if isModelMapped {
137141
jsonStr, err := json.Marshal(textRequest)
138142
if err != nil {
139-
return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
143+
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
140144
}
141145
requestBody = bytes.NewBuffer(jsonStr)
142146
} else {
@@ -145,11 +149,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
145149
} else {
146150
convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest)
147151
if err != nil {
148-
return service.OpenAIErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
152+
return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError)
149153
}
150154
jsonData, err := json.Marshal(convertedRequest)
151155
if err != nil {
152-
return service.OpenAIErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
156+
return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError)
153157
}
154158
requestBody = bytes.NewBuffer(jsonData)
155159
}
@@ -182,26 +186,39 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
182186
return nil
183187
}
184188

185-
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error, bool) {
189+
func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (int, error) {
186190
var promptTokens int
187191
var err error
188-
var sensitiveTrigger bool
189-
checkSensitive := constant.ShouldCheckPromptSensitive()
190192
switch info.RelayMode {
191193
case relayconstant.RelayModeChatCompletions:
192-
promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
194+
promptTokens, err = service.CountTokenChatRequest(*textRequest, textRequest.Model)
193195
case relayconstant.RelayModeCompletions:
194-
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
196+
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
195197
case relayconstant.RelayModeModerations:
196-
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
198+
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
197199
case relayconstant.RelayModeEmbeddings:
198-
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive)
200+
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
199201
default:
200202
err = errors.New("unknown relay mode")
201203
promptTokens = 0
202204
}
203205
info.PromptTokens = promptTokens
204-
return promptTokens, err, sensitiveTrigger
206+
return promptTokens, err
207+
}
208+
209+
func checkRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error {
210+
var err error
211+
switch info.RelayMode {
212+
case relayconstant.RelayModeChatCompletions:
213+
err = service.CheckSensitiveMessages(textRequest.Messages)
214+
case relayconstant.RelayModeCompletions:
215+
err = service.CheckSensitiveInput(textRequest.Prompt)
216+
case relayconstant.RelayModeModerations:
217+
err = service.CheckSensitiveInput(textRequest.Input)
218+
case relayconstant.RelayModeEmbeddings:
219+
err = service.CheckSensitiveInput(textRequest.Input)
220+
}
221+
return err
205222
}
206223

207224
// 预扣费并返回用户剩余配额

service/sensitive.go

+51-26
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,68 @@
11
package service
22

33
import (
4-
"bytes"
4+
"errors"
55
"fmt"
6-
"github.com/anknown/ahocorasick"
6+
"one-api/common"
77
"one-api/constant"
8+
"one-api/dto"
89
"strings"
910
)
1011

12+
func CheckSensitiveMessages(messages []dto.Message) error {
13+
for _, message := range messages {
14+
if len(message.Content) > 0 {
15+
if message.IsStringContent() {
16+
stringContent := message.StringContent()
17+
if ok, words := SensitiveWordContains(stringContent); ok {
18+
return errors.New("sensitive words: " + strings.Join(words, ","))
19+
}
20+
}
21+
} else {
22+
arrayContent := message.ParseContent()
23+
for _, m := range arrayContent {
24+
if m.Type == "image_url" {
25+
// TODO: check image url
26+
} else {
27+
if ok, words := SensitiveWordContains(m.Text); ok {
28+
return errors.New("sensitive words: " + strings.Join(words, ","))
29+
}
30+
}
31+
}
32+
}
33+
}
34+
return nil
35+
}
36+
37+
func CheckSensitiveText(text string) error {
38+
if ok, words := SensitiveWordContains(text); ok {
39+
return errors.New("sensitive words: " + strings.Join(words, ","))
40+
}
41+
return nil
42+
}
43+
44+
func CheckSensitiveInput(input any) error {
45+
switch v := input.(type) {
46+
case string:
47+
return CheckSensitiveText(v)
48+
case []string:
49+
text := ""
50+
for _, s := range v {
51+
text += s
52+
}
53+
return CheckSensitiveText(text)
54+
}
55+
return CheckSensitiveText(fmt.Sprintf("%v", input))
56+
}
57+
1158
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
1259
func SensitiveWordContains(text string) (bool, []string) {
1360
if len(constant.SensitiveWords) == 0 {
1461
return false, nil
1562
}
1663
checkText := strings.ToLower(text)
1764
// 构建一个AC自动机
18-
m := initAc()
65+
m := common.InitAc()
1966
hits := m.MultiPatternSearch([]rune(checkText), false)
2067
if len(hits) > 0 {
2168
words := make([]string, 0)
@@ -33,7 +80,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
3380
return false, nil, text
3481
}
3582
checkText := strings.ToLower(text)
36-
m := initAc()
83+
m := common.InitAc()
3784
hits := m.MultiPatternSearch([]rune(checkText), returnImmediately)
3885
if len(hits) > 0 {
3986
words := make([]string, 0)
@@ -47,25 +94,3 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string,
4794
}
4895
return false, nil, text
4996
}
50-
51-
func initAc() *goahocorasick.Machine {
52-
m := new(goahocorasick.Machine)
53-
dict := readRunes()
54-
if err := m.Build(dict); err != nil {
55-
fmt.Println(err)
56-
return nil
57-
}
58-
return m
59-
}
60-
61-
func readRunes() [][]rune {
62-
var dict [][]rune
63-
64-
for _, word := range constant.SensitiveWords {
65-
word = strings.ToLower(word)
66-
l := bytes.TrimSpace([]byte(word))
67-
dict = append(dict, bytes.Runes(l))
68-
}
69-
70-
return dict
71-
}

0 commit comments

Comments
 (0)