Skip to content

Commit

Permalink
Jina AI Embedding model integration (langchain4j#997)
Browse files Browse the repository at this point in the history
## Context
This PR integrates the Jina AI embedding model, as requested in issue
[973](langchain4j#973).

## Change
1. No Jina SDK is available for Java, so an HTTP client for the Jina API
was built as part of this change.
2. Added embedding generation for both a single input and multiple
inputs.
3. The default model is _jina-embeddings-v2-base-en_.

## Checklist
Before submitting this PR, please check the following points:
- [x] I have added unit and integration tests for my change
- [x] All unit and integration tests in the module I have added/changed
are green
- [x] All unit and integration tests in the
[core](https://github.com/langchain4j/langchain4j/tree/main/langchain4j-core)
and
[main](https://github.com/langchain4j/langchain4j/tree/main/langchain4j)
modules are green
- [ ] I have added/updated the
[documentation](https://github.com/langchain4j/langchain4j/tree/main/docs/docs)
- [ ] I have added an example in the [examples
repo](https://github.com/langchain4j/langchain4j-examples) (only for
"big" features)
- [ ] I have added my new module in the
[BOM](https://github.com/langchain4j/langchain4j/blob/main/langchain4j-bom/pom.xml)
(only when a new module is added)

## Checklist for adding new embedding store integration
- [ ] I have added a {NameOfIntegration}EmbeddingStoreIT that extends
from either EmbeddingStoreIT or EmbeddingStoreWithFilteringIT
  • Loading branch information
lucifer-Hell authored May 22, 2024
1 parent e223963 commit 0ad92d5
Show file tree
Hide file tree
Showing 10 changed files with 326 additions and 0 deletions.
60 changes: 60 additions & 0 deletions langchain4j-jina/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <!-- Dependency versions are not declared here; they are expected to come from the parent POM. -->
    <parent>
        <groupId>dev.langchain4j</groupId>
        <artifactId>langchain4j-parent</artifactId>
        <version>0.31.0-SNAPSHOT</version>
        <relativePath>../langchain4j-parent/pom.xml</relativePath>
    </parent>

    <artifactId>langchain4j-jina</artifactId>
    <name>LangChain4j :: Integration :: Jina</name>

    <dependencies>

        <!-- main dependencies -->
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-core</artifactId>
        </dependency>

        <!-- Lombok is only needed at compile time, so it must not be passed on to consumers. -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <scope>provided</scope>
        </dependency>

        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>retrofit</artifactId>
        </dependency>

        <dependency>
            <groupId>com.squareup.retrofit2</groupId>
            <artifactId>converter-gson</artifactId>
        </dependency>

        <!-- test dependencies -->
        <!-- AssertJ is used only by the integration test, so it is restricted to the test classpath. -->
        <dependency>
            <groupId>org.assertj</groupId>
            <artifactId>assertj-core</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-engine</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-params</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.tinylog</groupId>
            <artifactId>tinylog-impl</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.tinylog</groupId>
            <artifactId>slf4j-tinylog</artifactId>
            <scope>test</scope>
        </dependency>
    </dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package dev.langchain4j.model.jina;

import lombok.Builder;

import java.util.List;
/**
 * Request body sent to the embeddings endpoint.
 * Instances are created through the Lombok-generated builder.
 */
@Builder
public class EmbeddingRequest {

    /** Name of the embedding model to use. */
    String model;

    /** Texts to embed, one entry per input. */
    List<String> input;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package dev.langchain4j.model.jina;

import lombok.Data;

import java.util.List;
/**
 * Response body of the embeddings endpoint.
 * Accessors are generated by Lombok's {@code @Data}.
 */
@Data
public class EmbeddingResponse {

    /** Token accounting reported for the request. */
    Usage usage;

    /** Embedding entries contained in the response. */
    List<JinaEmbedding> data;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package dev.langchain4j.model.jina;

import retrofit2.Call;
import retrofit2.http.Body;
import retrofit2.http.Header;
import retrofit2.http.Headers;
import retrofit2.http.POST;

/**
 * Retrofit declaration of the Jina HTTP endpoints used by this module.
 */
public interface JinaApi {

    /**
     * Declares a JSON {@code POST} to {@code v1/embeddings}.
     *
     * @param request             payload sent as the request body
     * @param authorizationHeader value sent in the {@code Authorization} header
     * @return a call whose response body is an {@link EmbeddingResponse}
     */
    @Headers("Content-Type: application/json")
    @POST("v1/embeddings")
    Call<EmbeddingResponse> embed(
            @Body EmbeddingRequest request,
            @Header("Authorization") String authorizationHeader
    );
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package dev.langchain4j.model.jina;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import lombok.Builder;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
import retrofit2.converter.gson.GsonConverterFactory;

import java.io.IOException;
import java.time.Duration;

import static com.google.gson.FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;

public class JinaClient {
private static final Gson GSON = new GsonBuilder()
.setFieldNamingPolicy(LOWER_CASE_WITH_UNDERSCORES)
.setPrettyPrinting()
.create();

private final JinaApi jinaApi;
private final String authorizationHeader;

@Builder
JinaClient(String baseUrl, String apiKey, Duration timeout){
OkHttpClient.Builder okHttpClientBuilder = new OkHttpClient.Builder()
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
.writeTimeout(timeout);
Retrofit retrofit = new Retrofit.Builder()
.baseUrl(baseUrl)
.client(okHttpClientBuilder.build())
.addConverterFactory(GsonConverterFactory.create(GSON))
.build();


this.jinaApi = retrofit.create(JinaApi.class);
this.authorizationHeader = "Bearer " + ensureNotBlank(apiKey, "apiKey");
}

public EmbeddingResponse embed(EmbeddingRequest request) {
try {
retrofit2.Response<EmbeddingResponse> retrofitResponse
= jinaApi.embed(request, authorizationHeader).execute();

if (retrofitResponse.isSuccessful()) {
return retrofitResponse.body();
} else {
throw toException(retrofitResponse);
}
} catch (IOException e) {
throw new RuntimeException(e);
}
}



private static RuntimeException toException(retrofit2.Response<?> response) throws IOException {
int code = response.code();
String body = response.errorBody().string();
String errorMessage = String.format("status code: %s; body: %s", code, body);
return new RuntimeException(errorMessage);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package dev.langchain4j.model.jina;

import dev.langchain4j.data.embedding.Embedding;

/**
 * A single embedding entry of an embedding response.
 */
public class JinaEmbedding {

    /** Index reported for this entry; presumably the position of the matching input (confirm against the API reference). */
    long index;

    /** Raw embedding vector. */
    float[] embedding;

    /** Value of the {@code object} attribute of the entry. */
    String object;

    /**
     * Wraps the raw vector of this entry into a LangChain4j {@link Embedding}.
     *
     * @return the embedding created from {@link #embedding}
     */
    public Embedding toEmbedding() {
        float[] vector = this.embedding;
        return Embedding.from(vector);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package dev.langchain4j.model.jina;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import lombok.Builder;

import java.time.Duration;
import java.util.List;

import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static java.time.Duration.ofSeconds;
import static java.util.stream.Collectors.toList;

/**
 * An {@link EmbeddingModel} backed by the Jina AI Embeddings API.
 * See more details <a href="https://api.jina.ai/redoc#tag/embeddings">Jina API reference</a>
 */
public class JinaEmbeddingModel implements EmbeddingModel {

    private static final String DEFAULT_BASE_URL = "https://api.jina.ai/";
    private static final String DEFAULT_MODEL_NAME = "jina-embeddings-v2-base-en";
    private static final Duration DEFAULT_TIMEOUT = ofSeconds(60);
    private static final int DEFAULT_MAX_RETRIES = 3;

    private final JinaClient client;
    private final String modelName;
    private final Integer maxRetries;

    /**
     * Creates the model. Exposed through the Lombok-generated builder.
     *
     * @param baseUrl    root URL of the API; defaults to {@code https://api.jina.ai/}
     * @param apiKey     API key passed on to the client
     * @param modelName  model to use; defaults to {@code jina-embeddings-v2-base-en}
     * @param timeout    HTTP timeout; defaults to 60 seconds
     * @param maxRetries value handed to the retry helper for each request; defaults to 3
     */
    @Builder
    public JinaEmbeddingModel(String baseUrl,
                              String apiKey,
                              String modelName,
                              Duration timeout,
                              Integer maxRetries) {
        this.client = JinaClient.builder()
                .baseUrl(getOrDefault(baseUrl, DEFAULT_BASE_URL))
                .apiKey(apiKey)
                .timeout(getOrDefault(timeout, DEFAULT_TIMEOUT))
                .build();
        this.modelName = getOrDefault(modelName, DEFAULT_MODEL_NAME);
        this.maxRetries = getOrDefault(maxRetries, DEFAULT_MAX_RETRIES);
    }

    /**
     * Creates a model that uses the given API key and the defaults for everything else.
     *
     * @param apiKey the API key
     * @return a new model instance
     */
    public static JinaEmbeddingModel withApiKey(String apiKey) {
        return JinaEmbeddingModel.builder().apiKey(apiKey).build();
    }

    /**
     * Embeds all given segments with a single API request.
     *
     * @param textSegments the segments whose text is embedded
     * @return the embeddings, ordered by the index reported for each response entry, together with
     * the token usage (prompt tokens as input count, 0 as output count) when the API reports one
     */
    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        List<String> texts = textSegments.stream()
                .map(TextSegment::text)
                .collect(toList());

        EmbeddingRequest request = EmbeddingRequest.builder()
                .model(modelName)
                .input(texts)
                .build();

        EmbeddingResponse response = withRetry(() -> client.embed(request), maxRetries);

        // Order by the reported index instead of trusting the response order.
        // Assumes the index is the position of the matching input (see the API reference).
        // The sort is stable, so entries with equal indices keep their response order.
        List<Embedding> embeddings = response.getData().stream()
                .sorted(java.util.Comparator.comparingLong(entry -> entry.index))
                .map(JinaEmbedding::toEmbedding)
                .collect(toList());

        // A response without a usage block still yields the embeddings, just without token usage.
        Usage usage = response.getUsage();
        TokenUsage tokenUsage = usage == null ? null : new TokenUsage(usage.getPromptTokens(), 0);
        return Response.from(embeddings, tokenUsage);
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package dev.langchain4j.model.jina;

import com.google.gson.annotations.SerializedName;
import lombok.Getter;

/**
 * Token accounting block of an embedding response.
 * Getters are generated by Lombok.
 */
@Getter
class Usage {

    /** Mapped from the {@code prompt_tokens} JSON attribute. */
    @SerializedName("prompt_tokens")
    private Integer promptTokens;

    /** Mapped from the {@code total_tokens} JSON attribute. */
    @SerializedName("total_tokens")
    private Integer totalTokens;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package dev.langchain4j.model.jina;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.store.embedding.CosineSimilarity;
import org.junit.jupiter.api.Test;

import java.util.List;

import static java.time.Duration.ofSeconds;
import static java.util.Arrays.asList;
import static org.assertj.core.api.Assertions.assertThat;



/**
 * Integration tests for {@link JinaEmbeddingModel}.
 * The API key is read from the {@code JINA_AI_API_KEY} environment variable.
 */
public class JinaEmbeddingModelIT {

    /** A single text is embedded with the default settings. */
    @Test
    public void should_embed_single_text() {

        // given
        String apiKey = System.getenv("JINA_AI_API_KEY");
        EmbeddingModel embeddingModel = JinaEmbeddingModel.withApiKey(apiKey);

        // when
        Response<Embedding> result = embeddingModel.embed("hello");

        // then
        Embedding embedding = result.content();
        assertThat(embedding.dimension()).isEqualTo(768);

        assertThat(result.tokenUsage().inputTokenCount()).isEqualTo(3);
        assertThat(result.tokenUsage().outputTokenCount()).isEqualTo(0);
        assertThat(result.tokenUsage().totalTokenCount()).isEqualTo(3);
    }

    /** Two segments are embedded in one call by an explicitly configured model. */
    @Test
    public void should_embed_multiple_segments() {

        // given
        EmbeddingModel embeddingModel = JinaEmbeddingModel.builder()
                .baseUrl("https://api.jina.ai/")
                .apiKey(System.getenv("JINA_AI_API_KEY"))
                .modelName("jina-embeddings-v2-base-en")
                .timeout(ofSeconds(10))
                .maxRetries(2)
                .build();

        List<TextSegment> segments = asList(TextSegment.from("hello"), TextSegment.from("hi"));

        // when
        Response<List<Embedding>> result = embeddingModel.embedAll(segments);

        // then
        List<Embedding> embeddings = result.content();
        assertThat(embeddings).hasSize(2);

        Embedding first = embeddings.get(0);
        assertThat(first.dimension()).isEqualTo(768);

        Embedding second = embeddings.get(1);
        assertThat(second.dimension()).isEqualTo(768);

        assertThat(CosineSimilarity.between(first, second)).isGreaterThan(0.9);

        assertThat(result.tokenUsage().inputTokenCount()).isEqualTo(6);
        assertThat(result.tokenUsage().outputTokenCount()).isEqualTo(0);
        assertThat(result.tokenUsage().totalTokenCount()).isEqualTo(6);
    }
}
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
<module>langchain4j-vertex-ai</module>
<module>langchain4j-vertex-ai-gemini</module>
<module>langchain4j-zhipu-ai</module>
<module>langchain4j-jina</module>

<!-- embedding stores -->
<module>langchain4j-azure-ai-search</module>
Expand Down

0 comments on commit 0ad92d5

Please sign in to comment.