forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Jina AI Embedding model integration (langchain4j#997)
## Context This PR integrates the Jina AI embedding model requested in issue [973](langchain4j#973). ## Change 1. Since no Jina SDK is available for Java, a dedicated client was built for it. 2. Added embedding generation for both a single input and multiple inputs. 3. The default model is _jina-embeddings-v2-base-en_. ## Checklist Before submitting this PR, please check the following points: - [x] I have added unit and integration tests for my change - [x] All unit and integration tests in the module I have added/changed are green - [x] All unit and integration tests in the [core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core) and [main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j) modules are green - [ ] I have added/updated the [documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs) - [ ] I have added an example in the [examples repo](https://github.com/langchain4j/langchain4j-examples) (only for "big" features) - [ ] I have added my new module in the [BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml) (only when a new module is added) ## Checklist for adding new embedding store integration - [ ] I have added a {NameOfIntegration}EmbeddingStoreIT that extends from either EmbeddingStoreIT or EmbeddingStoreWithFilteringIT
- Loading branch information
1 parent
e223963
commit 0ad92d5
Showing
10 changed files
with
326 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> | ||
<modelVersion>4.0.0</modelVersion> | ||
<parent> | ||
<groupId>dev.langchain4j</groupId> | ||
<artifactId>langchain4j-parent</artifactId> | ||
<version>0.31.0-SNAPSHOT</version> | ||
<relativePath>../langchain4j-parent/pom.xml</relativePath> | ||
</parent> | ||
|
||
<artifactId>langchain4j-jina</artifactId> | ||
<name>LangChain4j :: Integration :: Jina</name> | ||
|
||
<dependencies>
    <dependency>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-core</artifactId>
    </dependency>

    <!-- Lombok only generates code at compile time; keep it off the runtime classpath of consumers. -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <scope>provided</scope>
    </dependency>

    <dependency>
        <groupId>com.squareup.retrofit2</groupId>
        <artifactId>retrofit</artifactId>
    </dependency>

    <dependency>
        <groupId>com.squareup.retrofit2</groupId>
        <artifactId>converter-gson</artifactId>
    </dependency>

    <!-- Test-only dependencies: none of these may leak into the published artifact. -->
    <dependency>
        <groupId>org.assertj</groupId>
        <artifactId>assertj-core</artifactId>
        <scope>test</scope>
    </dependency>

    <dependency>
        <groupId>org.junit.jupiter</groupId>
        <artifactId>junit-jupiter-engine</artifactId>
        <scope>test</scope>
    </dependency>

    <dependency>
        <groupId>org.junit.jupiter</groupId>
        <artifactId>junit-jupiter-params</artifactId>
        <scope>test</scope>
    </dependency>

    <dependency>
        <groupId>org.tinylog</groupId>
        <artifactId>tinylog-impl</artifactId>
        <scope>test</scope>
    </dependency>

    <dependency>
        <groupId>org.tinylog</groupId>
        <artifactId>slf4j-tinylog</artifactId>
        <scope>test</scope>
    </dependency>
</dependencies>
|
||
</project> |
10 changes: 10 additions & 0 deletions
10
langchain4j-jina/src/main/java/dev/langchain4j/model/jina/EmbeddingRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package dev.langchain4j.model.jina; | ||
|
||
import lombok.Builder; | ||
|
||
import java.util.List; | ||
@Builder | ||
public class EmbeddingRequest { | ||
String model; | ||
List<String> input; | ||
} |
10 changes: 10 additions & 0 deletions
10
langchain4j-jina/src/main/java/dev/langchain4j/model/jina/EmbeddingResponse.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
package dev.langchain4j.model.jina; | ||
|
||
import lombok.Data; | ||
|
||
import java.util.List; | ||
@Data | ||
public class EmbeddingResponse { | ||
Usage usage; | ||
List<JinaEmbedding> data; | ||
} |
14 changes: 14 additions & 0 deletions
14
langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaApi.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package dev.langchain4j.model.jina; | ||
|
||
import retrofit2.Call; | ||
import retrofit2.http.Body; | ||
import retrofit2.http.Header; | ||
import retrofit2.http.Headers; | ||
import retrofit2.http.POST; | ||
|
||
public interface JinaApi { | ||
@POST("v1/embeddings") | ||
@Headers({"Content-Type: application/json"}) | ||
Call<EmbeddingResponse> embed(@Body EmbeddingRequest request, @Header("Authorization") String authorizationHeader); | ||
|
||
} |
67 changes: 67 additions & 0 deletions
67
langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package dev.langchain4j.model.jina; | ||
|
||
import com.google.gson.Gson; | ||
import com.google.gson.GsonBuilder; | ||
import lombok.Builder; | ||
import okhttp3.OkHttpClient; | ||
import retrofit2.Retrofit; | ||
import retrofit2.converter.gson.GsonConverterFactory; | ||
|
||
import java.io.IOException; | ||
import java.time.Duration; | ||
|
||
import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES; | ||
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; | ||
|
||
public class JinaClient { | ||
private static final Gson GSON = new GsonBuilder() | ||
.setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES) | ||
.setPrettyPrinting() | ||
.create(); | ||
|
||
private final JinaApi jinaApi; | ||
private final String authorizationHeader; | ||
|
||
@Builder | ||
JinaClient(String baseUrl, String apiKey, Duration timeout){ | ||
OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder() | ||
.callTimeout(timeout) | ||
.connectTimeout(timeout) | ||
.readTimeout(timeout) | ||
.writeTimeout(timeout); | ||
Retrofit retrofit = new Retrofit.Builder() | ||
.baseUrl(baseUrl) | ||
.client(okHttpClientBuilder.build()) | ||
.addConverterFactory(GsonConverterFactory.create(GSON)) | ||
.build(); | ||
|
||
|
||
this.jinaApi = retrofit.create(JinaApi.class); | ||
this.authorizationHeader = "Bearer " + ensureNotBlank(apiKey, "apiKey"); | ||
} | ||
|
||
public EmbeddingResponse embed(EmbeddingRequest request) { | ||
try { | ||
retrofit2.Response<EmbeddingResponse> retrofitResponse | ||
= jinaApi.embed(request, authorizationHeader).execute(); | ||
|
||
if (retrofitResponse.isSuccessful()) { | ||
return retrofitResponse.body(); | ||
} else { | ||
throw toException(retrofitResponse); | ||
} | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
|
||
|
||
private static RuntimeException toException(retrofit2.Response<?> response) throws IOException { | ||
int code = response.code(); | ||
String body = response.errorBody().string(); | ||
String errorMessage = String.format("status code: %s; body: %s", code, body); | ||
return new RuntimeException(errorMessage); | ||
} | ||
|
||
} |
13 changes: 13 additions & 0 deletions
13
langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaEmbedding.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package dev.langchain4j.model.jina; | ||
|
||
import dev.langchain4j.data.embedding.Embedding; | ||
|
||
public class JinaEmbedding { | ||
long index; | ||
float[] embedding; | ||
String object; | ||
|
||
public Embedding toEmbedding(){ | ||
return Embedding.from(embedding); | ||
} | ||
} |
68 changes: 68 additions & 0 deletions
68
langchain4j-jina/src/main/java/dev/langchain4j/model/jina/JinaEmbeddingModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
package dev.langchain4j.model.jina; | ||
|
||
import dev.langchain4j.data.embedding.Embedding; | ||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.model.embedding.EmbeddingModel; | ||
import dev.langchain4j.model.output.Response; | ||
import dev.langchain4j.model.output.TokenUsage; | ||
import lombok.Builder; | ||
|
||
import java.time.Duration; | ||
import java.util.List; | ||
|
||
import static dev.langchain4j.internal.RetryUtils.withRetry; | ||
import static dev.langchain4j.internal.Utils.getOrDefault; | ||
import static java.time.Duration.ofSeconds; | ||
import static java.util.stream.Collectors.toList; | ||
|
||
/** | ||
* An integration with Nomic Atlas's Text Embeddings API. | ||
* See more details <a href="https://api.jina.ai/redoc#tag/embeddings">Jina API reference</a> | ||
*/ | ||
|
||
public class JinaEmbeddingModel implements EmbeddingModel { | ||
|
||
|
||
private static final String DEFAULT_BASE_URL = "https://api.jina.ai/"; | ||
|
||
private final JinaClient client; | ||
private final String modelName; | ||
private final Integer maxRetries; | ||
|
||
@Builder | ||
public JinaEmbeddingModel(String baseUrl, | ||
String apiKey, | ||
String modelName, | ||
Duration timeout, | ||
Integer maxRetries) { | ||
this.client = JinaClient.builder() | ||
.baseUrl(getOrDefault(baseUrl,DEFAULT_BASE_URL)) | ||
.apiKey(apiKey) | ||
.timeout(getOrDefault(timeout, ofSeconds(60))) | ||
.build(); | ||
this.modelName = getOrDefault(modelName, "jina-embeddings-v2-base-en"); | ||
this.maxRetries = getOrDefault(maxRetries, 3); | ||
} | ||
|
||
public static JinaEmbeddingModel withApiKey(String apiKey) { | ||
return JinaEmbeddingModel.builder().apiKey(apiKey).build(); | ||
} | ||
|
||
|
||
@Override | ||
public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) { | ||
EmbeddingRequest request = EmbeddingRequest.builder() | ||
.model(modelName) | ||
.input(textSegments.stream().map(TextSegment::text).collect(toList())) | ||
.build(); | ||
|
||
EmbeddingResponse response = withRetry(() -> client.embed(request), maxRetries); | ||
|
||
List<Embedding> embeddings = response.getData().stream() | ||
.map(JinaEmbedding::toEmbedding).collect(toList()); | ||
|
||
TokenUsage tokenUsage = new TokenUsage(response.getUsage().getPromptTokens(),0 ); | ||
return Response.from(embeddings,tokenUsage); | ||
} | ||
|
||
} |
12 changes: 12 additions & 0 deletions
12
langchain4j-jina/src/main/java/dev/langchain4j/model/jina/Usage.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
package dev.langchain4j.model.jina; | ||
|
||
import com.google.gson.annotations.SerializedName; | ||
import lombok.Getter; | ||
|
||
@Getter | ||
class Usage { | ||
@SerializedName("total_tokens") | ||
private Integer totalTokens; | ||
@SerializedName("prompt_tokens") | ||
private Integer promptTokens; | ||
} |
71 changes: 71 additions & 0 deletions
71
langchain4j-jina/src/test/java/dev/langchain4j/model/jina/JinaEmbeddingModelIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package dev.langchain4j.model.jina; | ||
|
||
import dev.langchain4j.data.embedding.Embedding; | ||
import dev.langchain4j.data.segment.TextSegment; | ||
import dev.langchain4j.model.embedding.EmbeddingModel; | ||
import dev.langchain4j.model.output.Response; | ||
import dev.langchain4j.store.embedding.CosineSimilarity; | ||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.List; | ||
|
||
import static java.time.Duration.ofSeconds; | ||
import static java.util.Arrays.asList; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
|
||
|
||
public class JinaEmbeddingModelIT { | ||
@Test | ||
public void should_embed_single_text() { | ||
|
||
// given | ||
EmbeddingModel model = JinaEmbeddingModel.withApiKey(System.getenv("JINA_AI_API_KEY")); | ||
|
||
String text = "hello"; | ||
|
||
// when | ||
Response<Embedding> response = model.embed(text); | ||
|
||
// then | ||
assertThat(response.content().dimension()).isEqualTo(768); | ||
|
||
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(3); | ||
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(0); | ||
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(3); | ||
} | ||
|
||
@Test | ||
public void should_embed_multiple_segments() { | ||
|
||
// given | ||
EmbeddingModel model = JinaEmbeddingModel.builder() | ||
.baseUrl("https://api.jina.ai/") | ||
.apiKey(System.getenv("JINA_AI_API_KEY")) | ||
.modelName("jina-embeddings-v2-base-en") | ||
.timeout(ofSeconds(10)) | ||
.maxRetries(2) | ||
.build(); | ||
|
||
TextSegment segment1 = TextSegment.from("hello"); | ||
TextSegment segment2 = TextSegment.from("hi"); | ||
|
||
// when | ||
Response<List<Embedding>> response = model.embedAll(asList(segment1, segment2)); | ||
|
||
// then | ||
assertThat(response.content()).hasSize(2); | ||
|
||
Embedding embedding1 = response.content().get(0); | ||
assertThat(embedding1.dimension()).isEqualTo(768); | ||
|
||
Embedding embedding2 = response.content().get(1); | ||
assertThat(embedding2.dimension()).isEqualTo(768); | ||
|
||
assertThat(CosineSimilarity.between(embedding1, embedding2)).isGreaterThan(0.9); | ||
|
||
assertThat(response.tokenUsage().inputTokenCount()).isEqualTo(6); | ||
assertThat(response.tokenUsage().outputTokenCount()).isEqualTo(0); | ||
assertThat(response.tokenUsage().totalTokenCount()).isEqualTo(6); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters