chore: add ChatModelListener support for ZhipuAi chat model and streaming chat model (langchain4j#1378)

## Issue
langchain4j#199 

## Change

Add `ChatModelListener` support for ZhipuAi (see the usage sketch after this list):
- chat model
- streaming chat model
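
For context, here is a minimal usage sketch (not part of this commit's diff) of how the new listener hooks could be wired in. It relies on the `ZhipuAiChatModel.builder().listeners(...)` option added by this change and the `ChatModelListener` callback/context types from `dev.langchain4j.model.chat.listener`; the context accessor names (`request()`, `response()`, `error()`) and the `ZHIPU_API_KEY` environment variable are assumptions for illustration only.

```java
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;

import static java.util.Collections.singletonList;

public class ZhipuAiListenerExample {

    public static void main(String[] args) {

        // Listener whose callbacks are invoked around each chat completion call.
        ChatModelListener listener = new ChatModelListener() {

            @Override
            public void onRequest(ChatModelRequestContext requestContext) {
                // accessor names (request(), model()) assumed from the listener API
                System.out.println("request to model: " + requestContext.request().model());
            }

            @Override
            public void onResponse(ChatModelResponseContext responseContext) {
                System.out.println("token usage: " + responseContext.response().tokenUsage());
            }

            @Override
            public void onError(ChatModelErrorContext errorContext) {
                System.err.println("error: " + errorContext.error().getMessage());
            }
        };

        ChatLanguageModel model = ZhipuAiChatModel.builder()
                .apiKey(System.getenv("ZHIPU_API_KEY")) // env var name is illustrative only
                .listeners(singletonList(listener))
                .build();

        System.out.println(model.generate("Hello!"));
    }
}
```

Per the PR description, the streaming chat model gains the same `listeners` support.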

## General checklist
<!-- Please double-check the following points and mark them like this:
[X] -->
- [x] There are no breaking changes
- [x] I have added unit and integration tests for my change
- [x] I have manually run all the unit and integration tests in the
module I have added/changed, and they are all green
- [x] I have manually run all the unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules, and they are all green
<!-- Before adding documentation and example(s) (below), please wait
until the PR is reviewed and approved. -->
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added/updated [Spring Boot
starter(s)](https://github.com/langchain4j/langchain4j-spring) (if
applicable)
1402564807 authored Jul 1, 2024
1 parent 544526b commit fc78f3f
Showing 10 changed files with 630 additions and 91 deletions.
@@ -17,6 +17,8 @@
import java.util.HashMap;
import java.util.Map;

import static com.google.common.net.HttpHeaders.ACCEPT;
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static java.lang.System.currentTimeMillis;
import static java.nio.charset.StandardCharsets.UTF_8;
@@ -57,8 +59,8 @@ public Response intercept(Chain chain) throws IOException {
String token = getOrDefault(cache.getIfPresent(this.apiKey), generateToken());
Request request = chain.request()
.newBuilder()
.addHeader("Authorization", "Bearer " + token)
.removeHeader("Accept")
.addHeader(AUTHORIZATION, "Bearer " + token)
.removeHeader(ACCEPT)
.build();
return chain.proceed(request);
}
@@ -7,12 +7,16 @@
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.message.*;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.zhipu.chat.*;
import dev.langchain4j.model.zhipu.embedding.EmbeddingResponse;
import dev.langchain4j.model.zhipu.shared.ErrorResponse;
import dev.langchain4j.model.zhipu.shared.Usage;
import lombok.Cleanup;
import okhttp3.ResponseBody;

import java.io.IOException;
@@ -27,13 +31,16 @@

class DefaultZhipuAiHelper {

public static List<Embedding> toEmbed(List<EmbeddingResponse> response) {
static final String FINISH_REASON_SENSITIVE = "sensitive";
static final String FINISH_REASON_OTHER = "other";

static List<Embedding> toEmbed(List<EmbeddingResponse> response) {
return response.stream()
.map(zhipuAiEmbedding -> Embedding.from(zhipuAiEmbedding.getEmbedding()))
.collect(Collectors.toList());
}

public static List<Tool> toTools(List<ToolSpecification> toolSpecifications) {
static List<Tool> toTools(List<ToolSpecification> toolSpecifications) {
return toolSpecifications.stream()
.map(toolSpecification -> Tool.from(toFunction(toolSpecification)))
.collect(Collectors.toList());
@@ -58,7 +65,7 @@ private static Parameters toFunctionParameters(ToolParameters toolParameters) {
}


public static List<Message> toZhipuAiMessages(List<ChatMessage> messages) {
static List<Message> toZhipuAiMessages(List<ChatMessage> messages) {
return messages.stream()
.map(DefaultZhipuAiHelper::toZhipuAiMessage)
.collect(Collectors.toList());
@@ -117,7 +124,7 @@ private static Message toZhipuAiMessage(ChatMessage message) {
throw illegalArgument("Unknown message type: " + message.type());
}

public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
static AiMessage aiMessageFrom(ChatCompletionResponse response) {
AssistantMessage message = response.getChoices().get(0).getMessage();
if (isNullOrEmpty(message.getToolCalls())) {
return AiMessage.from(message.getContent());
@@ -126,7 +133,7 @@ public static AiMessage aiMessageFrom(ChatCompletionResponse response) {
return AiMessage.from(specificationsFrom(message.getToolCalls()));
}

public static List<ToolExecutionRequest> specificationsFrom(List<ToolCall> toolCalls) {
static List<ToolExecutionRequest> specificationsFrom(List<ToolCall> toolCalls) {
List<ToolExecutionRequest> specifications = new ArrayList<>(toolCalls.size());
for (ToolCall toolCall : toolCalls) {
specifications.add(
@@ -140,7 +147,7 @@ public static List<ToolExecutionRequest> specificationsFrom(List<ToolCall> toolC
return specifications;
}

public static Usage getEmbeddingUsage(List<EmbeddingResponse> responses) {
static Usage getEmbeddingUsage(List<EmbeddingResponse> responses) {
Usage tokenUsage = Usage.builder()
.completionTokens(0)
.promptTokens(0)
@@ -154,7 +161,7 @@ public static Usage getEmbeddingUsage(List<EmbeddingResponse> responses) {
}


public static TokenUsage tokenUsageFrom(Usage zhipuUsage) {
static TokenUsage tokenUsageFrom(Usage zhipuUsage) {
if (zhipuUsage == null) {
return null;
}
@@ -165,39 +172,98 @@ public static TokenUsage tokenUsageFrom(Usage zhipuUsage) {
);
}

public static ChatCompletionResponse toChatErrorResponse(retrofit2.Response<?> retrofitResponse) throws IOException {
try (ResponseBody errorBody = retrofitResponse.errorBody()) {
return ChatCompletionResponse.builder()
.choices(Collections.singletonList(toChatErrorChoice(errorBody)))
.usage(Usage.builder().build())
.build();
}
static ChatCompletionResponse toChatErrorResponse(Object object) {
return ChatCompletionResponse.builder()
.choices(Collections.singletonList(toChatErrorChoice(object)))
.usage(Usage.builder().build())
.build();
}

/**
* Error codes: see the <a href="https://open.bigmodel.cn/dev/api#error-code-v3">error codes document</a>
*/
private static ChatCompletionChoice toChatErrorChoice(ResponseBody errorBody) throws IOException {
private static ChatCompletionChoice toChatErrorChoice(Object object) {
if (object instanceof Throwable) {
Throwable throwable = (Throwable) object;
return ChatCompletionChoice.builder()
.message(AssistantMessage.builder().content(throwable.getMessage()).build())
.finishReason(FINISH_REASON_OTHER)
.build();
}
@Cleanup ResponseBody errorBody = ((retrofit2.Response<?>) object).errorBody();

if (errorBody == null) {
return ChatCompletionChoice.builder()
.finishReason("other")
.finishReason(FINISH_REASON_OTHER)
.build();
}
ErrorResponse errorResponse = Json.fromJson(errorBody.string(), ErrorResponse.class);
// 1301: The system detected that the input or generated content may contain unsafe or sensitive content. Please avoid prompts that are likely to produce sensitive content. Thank you for your cooperation.
if ("1301".equals(errorResponse.getError().get("code"))) {
ErrorResponse errorResponse;
try {
errorResponse = Json.fromJson(errorBody.string(), ErrorResponse.class);
String code = errorResponse.getError().get("code");
return ChatCompletionChoice.builder()
.message(AssistantMessage.builder().content(errorResponse.getError().get("message")).build())
.finishReason("sensitive")
.finishReason(getFinishReason(code))
.build();
} catch (IOException e) {
return ChatCompletionChoice.builder()
.message(AssistantMessage.builder().content(e.getMessage()).build())
.finishReason(FINISH_REASON_OTHER)
.build();
}
return ChatCompletionChoice.builder()
.message(AssistantMessage.builder().content(errorResponse.getError().get("message")).build())
.finishReason("other")
}

static String getFinishReason(Object o) {
if (o instanceof String) {
// 1301: The system detected that the input or generated content may contain unsafe or sensitive content. Please avoid prompts that are likely to produce sensitive content. Thank you for your cooperation.
if ("1301".equals(o)) {
return FINISH_REASON_SENSITIVE;
}
}
if (o instanceof ZhipuAiException) {
ZhipuAiException exception = (ZhipuAiException) o;
if ("1301".equals(exception.getCode())) {
return FINISH_REASON_SENSITIVE;
}
}
return FINISH_REASON_OTHER;
}


static ChatModelRequest createModelListenerRequest(ChatCompletionRequest options,
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
return ChatModelRequest.builder()
.model(options.getModel())
.temperature(options.getTemperature())
.topP(options.getTopP())
.maxTokens(options.getMaxTokens())
.messages(messages)
.toolSpecifications(toolSpecifications)
.build();
}

public static FinishReason finishReasonFrom(String finishReason) {
static ChatModelResponse createModelListenerResponse(String responseId,
String responseModel,
Response<AiMessage> response) {
if (response == null) {
return null;
}

return ChatModelResponse.builder()
.id(responseId)
.model(responseModel)
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.aiMessage(response.content())
.build();
}

static boolean isSuccessFinishReason(FinishReason finishReason) {
return !CONTENT_FILTER.equals(finishReason) && !OTHER.equals(finishReason);
}

static FinishReason finishReasonFrom(String finishReason) {
if (finishReason == null) {
return null;
}
@@ -4,37 +4,48 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.*;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.zhipu.chat.ChatCompletionModel;
import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest;
import dev.langchain4j.model.zhipu.chat.ChatCompletionResponse;
import dev.langchain4j.model.zhipu.spi.ZhipuAiChatModelBuilderFactory;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.*;
import static dev.langchain4j.model.zhipu.chat.ChatCompletionModel.GLM_4;
import static dev.langchain4j.model.zhipu.chat.ToolChoiceMode.AUTO;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;

/**
* Represents a ZhipuAi language model with a chat completion interface, such as glm-3-turbo and glm-4.
* You can find a description of the parameters <a href="https://open.bigmodel.cn/dev/api">here</a>.
*/
@Slf4j
public class ZhipuAiChatModel implements ChatLanguageModel {

private final String baseUrl;
private final Double temperature;
private final Double topP;
private final String model;
private final Integer maxRetries;
private final Integer maxToken;
private final List<String> stops;
private final ZhipuAiClient client;
private final List<ChatModelListener> listeners;

@Builder
public ZhipuAiChatModel(
@@ -43,19 +54,22 @@ public ZhipuAiChatModel(
Double temperature,
Double topP,
String model,
List<String> stops,
Integer maxRetries,
Integer maxToken,
Boolean logRequests,
Boolean logResponses
Boolean logResponses,
List<ChatModelListener> listeners
) {
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
this.stops = stops;
this.model = getOrDefault(model, GLM_4.toString());
this.maxRetries = getOrDefault(maxRetries, 3);
this.maxToken = getOrDefault(maxToken, 512);
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
this.client = ZhipuAiClient.builder()
.baseUrl(this.baseUrl)
.baseUrl(getOrDefault(baseUrl, "https://open.bigmodel.cn/"))
.apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
@@ -80,23 +94,61 @@ public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecifi

ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(this.model)
.maxTokens(maxToken)
.maxTokens(this.maxToken)
.stream(false)
.topP(topP)
.temperature(temperature)
.topP(this.topP)
.stop(this.stops)
.temperature(this.temperature)
.toolChoice(AUTO)
.messages(toZhipuAiMessages(messages));

if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
}

ChatCompletionResponse response = withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
return Response.from(
ChatCompletionRequest request = requestBuilder.build();
ChatModelRequest modelListenerRequest = createModelListenerRequest(request, messages, toolSpecifications);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes);
for (ChatModelListener chatModelListener : listeners) {
try {
chatModelListener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
}

ChatCompletionResponse response = withRetry(() -> client.chatCompletion(request), maxRetries);

FinishReason finishReason = finishReasonFrom(response.getChoices().get(0).getFinishReason());

Response<AiMessage> messageResponse = Response.from(
aiMessageFrom(response),
tokenUsageFrom(response.getUsage()),
finishReasonFrom(response.getChoices().get(0).getFinishReason())
finishReason
);

listeners.forEach(listener -> {
try {
if (isSuccessFinishReason(finishReason)) {
listener.onResponse(new ChatModelResponseContext(
createModelListenerResponse(response.getId(), request.getModel(), messageResponse),
modelListenerRequest,
attributes
));
} else {
listener.onError(new ChatModelErrorContext(
new ZhipuAiException(messageResponse.content().text()),
modelListenerRequest,
null,
attributes
));
}
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
return messageResponse;
}

@Override
@@ -107,5 +159,16 @@ public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecificatio
public static class ZhipuAiChatModelBuilder {
public ZhipuAiChatModelBuilder() {
}

public ZhipuAiChatModelBuilder model(ChatCompletionModel model) {
this.model = model.toString();
return this;
}

public ZhipuAiChatModelBuilder model(String model) {
ensureNotBlank(model, "model");
this.model = model;
return this;
}
}
}