Skip to content

Commit 25b179e

Browse files
authored
Merge pull request CodePhiliaX#1599 from tmlx1990/ai
fix:1559 修复自定义AI不能使用的问题。
2 parents 4c51a19 + be4ecc2 commit 25b179e

File tree

7 files changed

+338
-188
lines changed

7 files changed

+338
-188
lines changed

chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/ChatController.java

+10-4
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ public SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseE
235235
case CHAT2DBAI:
236236
return chatWithChat2dbAi(queryRequest, sseEmitter, uid);
237237
case RESTAI :
238-
return chatWithRestAi(queryRequest, sseEmitter);
238+
return chatWithRestAi(queryRequest, sseEmitter, uid);
239239
case FASTCHATAI:
240240
return chatWithFastChatAi(queryRequest, sseEmitter, uid);
241241
case AZUREAI :
@@ -261,9 +261,15 @@ public SseEmitter distributeAISql(ChatQueryRequest queryRequest, SseEmitter sseE
261261
* @param sseEmitter
262262
* @return
263263
*/
264-
private SseEmitter chatWithRestAi(ChatQueryRequest prompt, SseEmitter sseEmitter) {
265-
RestAIEventSourceListener eventSourceListener = new RestAIEventSourceListener(sseEmitter);
266-
RestAIClient.getInstance().restCompletions(buildPrompt(prompt), eventSourceListener);
264+
private SseEmitter chatWithRestAi(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid) throws IOException {
265+
String prompt = buildPrompt(queryRequest);
266+
List<FastChatMessage> messages = getFastChatMessage(uid, prompt);
267+
268+
buildSseEmitter(sseEmitter, uid);
269+
270+
RestAIEventSourceListener restAIEventSourceListener = new RestAIEventSourceListener(sseEmitter);
271+
RestAIClient.getInstance().streamCompletions(messages, restAIEventSourceListener);
272+
LocalCache.CACHE.put(uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
267273
return sseEmitter;
268274
}
269275

chat2db-server/chat2db-server-web/chat2db-server-web-api/src/main/java/ai/chat2db/server/web/api/controller/ai/rest/client/RestAIClient.java

+26-7
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import ai.chat2db.server.web.api.util.ApplicationContextUtil;
77

88
import lombok.extern.slf4j.Slf4j;
9+
import org.apache.commons.lang3.StringUtils;
910

1011
/**
1112
* @author moji
@@ -19,6 +20,11 @@ public class RestAIClient {
1920
*/
2021
public static final String AI_SQL_SOURCE = "ai.sql.source";
2122

23+
/**
24+
* Customized AI interface KEY
25+
*/
26+
public static final String REST_AI_API_KEY = "rest.ai.apiKey";
27+
2228
/**
2329
* Customized AI interface address
2430
*/
@@ -29,17 +35,24 @@ public class RestAIClient {
2935
*/
3036
public static final String REST_AI_STREAM_OUT = "rest.ai.stream";
3137

32-
private static RestAiStreamClient REST_AI_STREAM_CLIENT;
38+
/**
39+
* Custom AI interface model
40+
*/
41+
public static final String REST_AI_MODEL = "rest.ai.model";
3342

34-
public static RestAiStreamClient getInstance() {
43+
44+
45+
private static RestAIStreamClient REST_AI_STREAM_CLIENT;
46+
47+
public static RestAIStreamClient getInstance() {
3548
if (REST_AI_STREAM_CLIENT != null) {
3649
return REST_AI_STREAM_CLIENT;
3750
} else {
3851
return singleton();
3952
}
4053
}
4154

42-
private static RestAiStreamClient singleton() {
55+
private static RestAIStreamClient singleton() {
4356
if (REST_AI_STREAM_CLIENT == null) {
4457
synchronized (RestAIClient.class) {
4558
if (REST_AI_STREAM_CLIENT == null) {
@@ -55,17 +68,23 @@ private static RestAiStreamClient singleton() {
5568
*/
5669
public static void refresh() {
5770
String apiUrl = "";
58-
Boolean stream = Boolean.TRUE;
71+
String apiKey = "";
72+
String model = "";
5973
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
6074
Config apiHostConfig = configService.find(REST_AI_URL).getData();
6175
if (apiHostConfig != null) {
6276
apiUrl = apiHostConfig.getContent();
6377
}
64-
Config config = configService.find(REST_AI_STREAM_OUT).getData();
78+
Config config = configService.find(REST_AI_API_KEY).getData();
6579
if (config != null) {
66-
stream = Boolean.valueOf(config.getContent());
80+
apiKey = config.getContent();
81+
}
82+
Config deployConfig = configService.find(REST_AI_MODEL).getData();
83+
if (deployConfig != null && StringUtils.isNotBlank(deployConfig.getContent())) {
84+
model = deployConfig.getContent();
6785
}
68-
REST_AI_STREAM_CLIENT = new RestAiStreamClient(apiUrl, stream);
86+
REST_AI_STREAM_CLIENT = RestAIStreamClient.builder().apiKey(apiKey).apiHost(apiUrl).model(model)
87+
.build();
6988
}
7089

7190
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
package ai.chat2db.server.web.api.controller.ai.rest.client;
2+
3+
import ai.chat2db.server.tools.common.exception.ParamBusinessException;
4+
import ai.chat2db.server.web.api.controller.ai.fastchat.interceptor.FastChatHeaderAuthorizationInterceptor;
5+
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatCompletionsOptions;
6+
import ai.chat2db.server.web.api.controller.ai.fastchat.model.FastChatMessage;
7+
import cn.hutool.http.ContentType;
8+
import com.fasterxml.jackson.databind.DeserializationFeature;
9+
import com.fasterxml.jackson.databind.ObjectMapper;
10+
import lombok.Getter;
11+
import lombok.extern.slf4j.Slf4j;
12+
import okhttp3.MediaType;
13+
import okhttp3.OkHttpClient;
14+
import okhttp3.Request;
15+
import okhttp3.RequestBody;
16+
import okhttp3.sse.EventSource;
17+
import okhttp3.sse.EventSourceListener;
18+
import okhttp3.sse.EventSources;
19+
import org.apache.commons.collections4.CollectionUtils;
20+
import org.jetbrains.annotations.NotNull;
21+
22+
import java.util.List;
23+
import java.util.Objects;
24+
import java.util.concurrent.TimeUnit;
25+
26+
/**
27+
* Custom AI interface client
28+
* @author moji
29+
*/
30+
@Slf4j
31+
public class RestAIStreamClient {
32+
/**
33+
* apikey
34+
*/
35+
@Getter
36+
@NotNull
37+
private String apiKey;
38+
39+
/**
40+
* apiHost
41+
*/
42+
@Getter
43+
@NotNull
44+
private String apiHost;
45+
46+
/**
47+
* model
48+
*/
49+
@Getter
50+
private String model;
51+
/**
52+
* okHttpClient
53+
*/
54+
@Getter
55+
private OkHttpClient okHttpClient;
56+
57+
/**
58+
* Construct instance object
59+
*
60+
* @param builder
61+
*/
62+
public RestAIStreamClient(Builder builder) {
63+
this.apiKey = builder.apiKey;
64+
this.apiHost = builder.apiHost;
65+
this.model = builder.model;
66+
this.okHttpClient = new OkHttpClient
67+
.Builder()
68+
.addInterceptor(new FastChatHeaderAuthorizationInterceptor(this.apiKey))
69+
.connectTimeout(10, TimeUnit.SECONDS)
70+
.writeTimeout(50, TimeUnit.SECONDS)
71+
.readTimeout(50, TimeUnit.SECONDS)
72+
.build();
73+
}
74+
75+
/**
76+
* structure
77+
*
78+
* @return
79+
*/
80+
public static RestAIStreamClient.Builder builder() {
81+
return new RestAIStreamClient.Builder();
82+
}
83+
84+
/**
85+
* builder
86+
*/
87+
public static final class Builder {
88+
private String apiKey;
89+
90+
private String apiHost;
91+
92+
private String model;
93+
94+
95+
/**
96+
* OkhttpClient
97+
*/
98+
private OkHttpClient okHttpClient;
99+
100+
public Builder() {
101+
}
102+
103+
public RestAIStreamClient.Builder apiKey(String apiKeyValue) {
104+
this.apiKey = apiKeyValue;
105+
return this;
106+
}
107+
108+
/**
109+
* @param apiHostValue
110+
* @return
111+
*/
112+
public RestAIStreamClient.Builder apiHost(String apiHostValue) {
113+
this.apiHost = apiHostValue;
114+
return this;
115+
}
116+
117+
/**
118+
* @param modelValue
119+
* @return
120+
*/
121+
public RestAIStreamClient.Builder model(String modelValue) {
122+
this.model = modelValue;
123+
return this;
124+
}
125+
126+
127+
public RestAIStreamClient.Builder okHttpClient(OkHttpClient val) {
128+
this.okHttpClient = val;
129+
return this;
130+
}
131+
132+
public RestAIStreamClient build() {
133+
return new RestAIStreamClient(this);
134+
}
135+
136+
}
137+
138+
139+
/**
140+
* Q&A interface stream form
141+
*
142+
* @param chatMessages
143+
* @param eventSourceListener
144+
*/
145+
public void streamCompletions(List<FastChatMessage> chatMessages, EventSourceListener eventSourceListener) {
146+
if (CollectionUtils.isEmpty(chatMessages)) {
147+
log.error("param error:Rest AI Prompt cannot be empty");
148+
throw new ParamBusinessException("prompt");
149+
}
150+
if (Objects.isNull(eventSourceListener)) {
151+
log.error("param error:RestAIEventSourceListener cannot be empty");
152+
throw new ParamBusinessException();
153+
}
154+
log.info("Rest AI, prompt:{}", chatMessages.get(chatMessages.size() - 1).getContent());
155+
try {
156+
157+
FastChatCompletionsOptions chatCompletionsOptions = new FastChatCompletionsOptions(chatMessages);
158+
chatCompletionsOptions.setStream(true);
159+
chatCompletionsOptions.setModel(this.model);
160+
161+
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
162+
ObjectMapper mapper = new ObjectMapper();
163+
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
164+
String requestBody = mapper.writeValueAsString(chatCompletionsOptions);
165+
Request request = new Request.Builder()
166+
.url(apiHost)
167+
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
168+
.build();
169+
//Create event
170+
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
171+
log.info("finish invoking rest ai");
172+
} catch (Exception e) {
173+
log.error("rest ai error", e);
174+
eventSourceListener.onFailure(null, e, null);
175+
throw new ParamBusinessException();
176+
}
177+
}
178+
179+
180+
}

0 commit comments

Comments
 (0)