Skip to content

Commit

Permalink
add ZhipuAI integration (langchain4j#558)
Browse files Browse the repository at this point in the history
ZhipuAI is a large model focusing on Chinese cognition

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
	- Introduced serialization and deserialization for assistant messages.
	- Enhanced utility methods for data conversion in Zhipu AI processing.
	- Implemented JSON serialization/deserialization support.
	- Defined an interface for Zhipu AI service interactions.
- Introduced classes for handling chat completions and embedding
requests with Zhipu AI.
- Provided structure for chat messages, choices, models, requests, and
responses.
- Added classes for function calls, parameters, tool interactions, and
web searches within chats.
	- Established data structures for embedding information and requests.
	- Implemented builders for chat and embedding model instances.
- **Tests**
- Added integration tests for chat model, embedding model, and streaming
chat model functionalities.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
1402564807 authored Mar 11, 2024
1 parent e54037b commit 8fcd192
Show file tree
Hide file tree
Showing 46 changed files with 3,054 additions and 0 deletions.
6 changes: 6 additions & 0 deletions langchain4j-bom/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-zhipu-ai</artifactId>
<version>${project.version}</version>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-ollama</artifactId>
Expand Down
108 changes: 108 additions & 0 deletions langchain4j-zhipu-ai/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.28.0-SNAPSHOT</version>
<relativePath>../langchain4j-parent/pom.xml</relativePath>
</parent>

<artifactId>langchain4j-zhipu-ai</artifactId>
<packaging>jar</packaging>

<name>LangChain4j :: Integration :: Zhipu AI</name>

<dependencies>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>retrofit</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.retrofit2</groupId>
<artifactId>converter-gson</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
</dependency>

<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-sse</artifactId>
<version>${okhttp.version}</version>
</dependency>

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt</artifactId>
<version>0.12.3</version>
</dependency>

<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>32.0.0-jre</version>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>

</dependencies>

<licenses>
<license>
<name>Apache-2.0</name>
<url>https://www.apache.org/licenses/LICENSE-2.0.txt</url>
<distribution>repo</distribution>
<comments>A business-friendly OSS license</comments>
</license>
</licenses>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package dev.langchain4j.model.zhipu;

import com.google.gson.Gson;
import com.google.gson.TypeAdapter;
import com.google.gson.TypeAdapterFactory;
import com.google.gson.reflect.TypeToken;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonWriter;
import dev.langchain4j.model.zhipu.chat.AssistantMessage;
import dev.langchain4j.model.zhipu.chat.ToolCall;

import java.io.IOException;
import java.util.List;

class AssistantMessageTypeAdapter extends TypeAdapter<AssistantMessage> {

static final TypeAdapterFactory ASSISTANT_MESSAGE_TYPE_ADAPTER_FACTORY = new TypeAdapterFactory() {

@Override
@SuppressWarnings("unchecked")
public <T> TypeAdapter<T> create(Gson gson, TypeToken<T> type) {
if (type.getRawType() != AssistantMessage.class) {
return null;
}
TypeAdapter<AssistantMessage> delegate =
(TypeAdapter<AssistantMessage>) gson.getDelegateAdapter(this, type);
return (TypeAdapter<T>) new AssistantMessageTypeAdapter(delegate);
}
};

private final TypeAdapter<AssistantMessage> delegate;

private AssistantMessageTypeAdapter(TypeAdapter<AssistantMessage> delegate) {
this.delegate = delegate;
}

@Override
public void write(JsonWriter out, AssistantMessage assistantMessage) throws IOException {
out.beginObject();

out.name("role");
out.value(assistantMessage.getRole().toString().toLowerCase());

out.name("content");
if (assistantMessage.getContent() == null) {
boolean serializeNulls = out.getSerializeNulls();
out.setSerializeNulls(true);
out.nullValue(); // serialize "content": null
out.setSerializeNulls(serializeNulls);
} else {
out.value(assistantMessage.getContent());
}

if (assistantMessage.getName() != null) {
out.name("name");
out.value(assistantMessage.getName());
}

List<ToolCall> toolCalls = assistantMessage.getToolCalls();
if (toolCalls != null && !toolCalls.isEmpty()) {
out.name("tool_calls");
out.beginArray();
TypeAdapter<ToolCall> toolCallTypeAdapter = Json.GSON.getAdapter(ToolCall.class);
for (ToolCall toolCall : toolCalls) {
toolCallTypeAdapter.write(out, toolCall);
}
out.endArray();
}

out.endObject();
}

@Override
public AssistantMessage read(JsonReader in) throws IOException {
return delegate.read(in);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package dev.langchain4j.model.zhipu;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.security.MacAlgorithm;
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;

import javax.crypto.spec.SecretKeySpec;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static java.lang.System.currentTimeMillis;
import static java.nio.charset.StandardCharsets.UTF_8;

class AuthorizationInterceptor implements Interceptor {

private static final long expireMillis = 1000 * 60 * 30;
private static final String id = "HS256";
private static final String jcaName = "HmacSHA256";
private static final MacAlgorithm macAlgorithm;

static {
try {
//create a custom MacAlgorithm with a custom minKeyBitLength
int minKeyBitLength = 128;
Class<?> c = Class.forName("io.jsonwebtoken.impl.security.DefaultMacAlgorithm");
Constructor<?> ctor = c.getDeclaredConstructor(String.class, String.class, int.class);
ctor.setAccessible(true);
macAlgorithm = (MacAlgorithm) ctor.newInstance(id, jcaName, minKeyBitLength);
} catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException |
InvocationTargetException e) {
throw new RuntimeException(e);
}
}

private final String apiKey;
private final Cache<String, String> cache = CacheBuilder.newBuilder()
.expireAfterWrite(Duration.ofMillis(expireMillis))
.build();


public AuthorizationInterceptor(String apiKey) {
this.apiKey = apiKey;
}

@Override
public Response intercept(Chain chain) throws IOException {
String token = getOrDefault(cache.getIfPresent(this.apiKey), generateToken());
Request request = chain.request()
.newBuilder()
.addHeader("Authorization", "Bearer " + token)
.removeHeader("Accept")
.build();
return chain.proceed(request);
}

private String generateToken() {
String[] apiKeyParts = this.apiKey.split("\\.");
String keyId = apiKeyParts[0];
String secret = apiKeyParts[1];
Map<String, Object> payload = new HashMap<>(3);
payload.put("api_key", keyId);
payload.put("exp", currentTimeMillis() + expireMillis);
payload.put("timestamp", currentTimeMillis());

String token = Jwts.builder()
.header()
.add("alg", id)
.add("sign_type", "SIGN")
.and()
.content(Json.toJson(payload))
.signWith(new SecretKeySpec(secret.getBytes(UTF_8), jcaName), macAlgorithm)
.compact();
cache.put(this.apiKey, token);
return token;
}

}
Loading

0 comments on commit 8fcd192

Please sign in to comment.