forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add ZhipuAI integration (langchain4j#558)
ZhipuAI is a large model focusing on Chinese cognition <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced serialization and deserialization for assistant messages. - Enhanced utility methods for data conversion in Zhipu AI processing. - Implemented JSON serialization/deserialization support. - Defined an interface for Zhipu AI service interactions. - Introduced classes for handling chat completions and embedding requests with Zhipu AI. - Provided structure for chat messages, choices, models, requests, and responses. - Added classes for function calls, parameters, tool interactions, and web searches within chats. - Established data structures for embedding information and requests. - Implemented builders for chat and embedding model instances. - **Tests** - Added integration tests for chat model, embedding model, and streaming chat model functionalities. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
- Loading branch information
1 parent
e54037b
commit 8fcd192
Showing
46 changed files
with
3,054 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" | ||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
|
||
<parent> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-parent</artifactId> | ||
<version>0.28.0-SNAPSHOT</version> | ||
<relativePath>../langchain4j-parent/pom.xml</relativePath> | ||
</parent> | ||
|
||
<artifactId>langchain4j-zhipu-ai</artifactId> | ||
<packaging>jar</packaging> | ||
|
||
<name>LangChain4j :: Integration :: Zhipu AI</name> | ||
|
||
<dependencies> | ||
|
||
<dependency> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-core</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.squareup.retrofit2</groupId> | ||
<artifactId>retrofit</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.squareup.retrofit2</groupId> | ||
<artifactId>converter-gson</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.squareup.okhttp3</groupId> | ||
<artifactId>okhttp</artifactId> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.squareup.okhttp3</groupId> | ||
<artifactId>okhttp-sse</artifactId> | ||
<version>${okhttp.version}</version> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>io.jsonwebtoken</groupId> | ||
<artifactId>jjwt</artifactId> | ||
<version>0.12.3</version> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>com.google.guava</groupId> | ||
<artifactId>guava</artifactId> | ||
<version>32.0.0-jre</version> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.projectlombok</groupId> | ||
<artifactId>lombok</artifactId> | ||
<scope>provided</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.junit.jupiter</groupId> | ||
<artifactId>junit-jupiter-engine</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.assertj</groupId> | ||
<artifactId>assertj-core</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.tinylog</groupId> | ||
<artifactId>tinylog-impl</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>org.tinylog</groupId> | ||
<artifactId>slf4j-tinylog</artifactId> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
<dependency> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-core</artifactId> | ||
<classifier>tests</classifier> | ||
<type>test-jar</type> | ||
<scope>test</scope> | ||
</dependency> | ||
|
||
</dependencies> | ||
|
||
<licenses> | ||
<license> | ||
<name>Apache-2.0</name> | ||
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url> | ||
<distribution>repo</distribution> | ||
<comments>A business-friendly OSS license</comments> | ||
</license> | ||
</licenses> | ||
|
||
</project> |
77 changes: 77 additions & 0 deletions
77
...ain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/AssistantMessageTypeAdapter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
package dev.langchain4j.model.zhipu; | ||
|
||
import com.google.gson.Gson; | ||
import com.google.gson.TypeAdapter; | ||
import com.google.gson.TypeAdapterFactory; | ||
import com.google.gson.reflect.TypeToken; | ||
import com.google.gson.stream.JsonReader; | ||
import com.google.gson.stream.JsonWriter; | ||
import dev.langchain4j.model.zhipu.chat.AssistantMessage; | ||
import dev.langchain4j.model.zhipu.chat.ToolCall; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
|
||
class AssistantMessageTypeAdapter extends TypeAdapter<AssistantMessage> { | ||
|
||
static final TypeAdapterFactory ASSISTANT_MESSAGE_TYPE_ADAPTER_FACTORY = new TypeAdapterFactory() { | ||
|
||
@Override | ||
@SuppressWarnings("unchecked") | ||
public <T> TypeAdapter<T> create(Gson gson, TypeToken<T> type) { | ||
if (type.getRawType() != AssistantMessage.class) { | ||
return null; | ||
} | ||
TypeAdapter<AssistantMessage> delegate = | ||
(TypeAdapter<AssistantMessage>) gson.getDelegateAdapter(this, type); | ||
return (TypeAdapter<T>) new AssistantMessageTypeAdapter(delegate); | ||
} | ||
}; | ||
|
||
private final TypeAdapter<AssistantMessage> delegate; | ||
|
||
private AssistantMessageTypeAdapter(TypeAdapter<AssistantMessage> delegate) { | ||
this.delegate = delegate; | ||
} | ||
|
||
@Override | ||
public void write(JsonWriter out, AssistantMessage assistantMessage) throws IOException { | ||
out.beginObject(); | ||
|
||
out.name("role"); | ||
out.value(assistantMessage.getRole().toString().toLowerCase()); | ||
|
||
out.name("content"); | ||
if (assistantMessage.getContent() == null) { | ||
boolean serializeNulls = out.getSerializeNulls(); | ||
out.setSerializeNulls(true); | ||
out.nullValue(); // serialize "content": null | ||
out.setSerializeNulls(serializeNulls); | ||
} else { | ||
out.value(assistantMessage.getContent()); | ||
} | ||
|
||
if (assistantMessage.getName() != null) { | ||
out.name("name"); | ||
out.value(assistantMessage.getName()); | ||
} | ||
|
||
List<ToolCall> toolCalls = assistantMessage.getToolCalls(); | ||
if (toolCalls != null && !toolCalls.isEmpty()) { | ||
out.name("tool_calls"); | ||
out.beginArray(); | ||
TypeAdapter<ToolCall> toolCallTypeAdapter = Json.GSON.getAdapter(ToolCall.class); | ||
for (ToolCall toolCall : toolCalls) { | ||
toolCallTypeAdapter.write(out, toolCall); | ||
} | ||
out.endArray(); | ||
} | ||
|
||
out.endObject(); | ||
} | ||
|
||
@Override | ||
public AssistantMessage read(JsonReader in) throws IOException { | ||
return delegate.read(in); | ||
} | ||
} |
86 changes: 86 additions & 0 deletions
86
langchain4j-zhipu-ai/src/main/java/dev/langchain4j/model/zhipu/AuthorizationInterceptor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package dev.langchain4j.model.zhipu; | ||
|
||
import com.google.common.cache.Cache; | ||
import com.google.common.cache.CacheBuilder; | ||
import io.jsonwebtoken.Jwts; | ||
import io.jsonwebtoken.security.MacAlgorithm; | ||
import okhttp3.Interceptor; | ||
import okhttp3.Request; | ||
import okhttp3.Response; | ||
|
||
import javax.crypto.spec.SecretKeySpec; | ||
import java.io.IOException; | ||
import java.lang.reflect.Constructor; | ||
import java.lang.reflect.InvocationTargetException; | ||
import java.time.Duration; | ||
import java.util.HashMap; | ||
import java.util.Map; | ||
|
||
import static dev.langchain4j.internal.Utils.getOrDefault; | ||
import static java.lang.System.currentTimeMillis; | ||
import static java.nio.charset.StandardCharsets.UTF_8; | ||
|
||
class AuthorizationInterceptor implements Interceptor { | ||
|
||
private static final long expireMillis = 1000 * 60 * 30; | ||
private static final String id = "HS256"; | ||
private static final String jcaName = "HmacSHA256"; | ||
private static final MacAlgorithm macAlgorithm; | ||
|
||
static { | ||
try { | ||
//create a custom MacAlgorithm with a custom minKeyBitLength | ||
int minKeyBitLength = 128; | ||
Class<?> c = Class.forName("io.jsonwebtoken.impl.security.DefaultMacAlgorithm"); | ||
Constructor<?> ctor = c.getDeclaredConstructor(String.class, String.class, int.class); | ||
ctor.setAccessible(true); | ||
macAlgorithm = (MacAlgorithm) ctor.newInstance(id, jcaName, minKeyBitLength); | ||
} catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException | | ||
InvocationTargetException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
private final String apiKey; | ||
private final Cache<String, String> cache = CacheBuilder.newBuilder() | ||
.expireAfterWrite(Duration.ofMillis(expireMillis)) | ||
.build(); | ||
|
||
|
||
public AuthorizationInterceptor(String apiKey) { | ||
this.apiKey = apiKey; | ||
} | ||
|
||
@Override | ||
public Response intercept(Chain chain) throws IOException { | ||
String token = getOrDefault(cache.getIfPresent(this.apiKey), generateToken()); | ||
Request request = chain.request() | ||
.newBuilder() | ||
.addHeader("Authorization", "Bearer " + token) | ||
.removeHeader("Accept") | ||
.build(); | ||
return chain.proceed(request); | ||
} | ||
|
||
private String generateToken() { | ||
String[] apiKeyParts = this.apiKey.split("\\."); | ||
String keyId = apiKeyParts[0]; | ||
String secret = apiKeyParts[1]; | ||
Map<String, Object> payload = new HashMap<>(3); | ||
payload.put("api_key", keyId); | ||
payload.put("exp", currentTimeMillis() + expireMillis); | ||
payload.put("timestamp", currentTimeMillis()); | ||
|
||
String token = Jwts.builder() | ||
.header() | ||
.add("alg", id) | ||
.add("sign_type", "SIGN") | ||
.and() | ||
.content(Json.toJson(payload)) | ||
.signWith(new SecretKeySpec(secret.getBytes(UTF_8), jcaName), macAlgorithm) | ||
.compact(); | ||
cache.put(this.apiKey, token); | ||
return token; | ||
} | ||
|
||
} |
Oops, something went wrong.