Skip to content

Commit

Permalink
PGVector, remove by ID, same impl as langchain4j#1020 but on top of last dev. (langchain4j#1113)
Browse files Browse the repository at this point in the history

langchain4j#301: Add support for remove operations in EmbeddingStore (PGVector)
  • Loading branch information
LangChain4j committed May 23, 2024
1 parent 1348786 commit 358f63c
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 137 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package dev.langchain4j.store.embedding;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.filter.Filter;

import java.util.Collection;
import java.util.List;

/**
Expand Down Expand Up @@ -57,41 +59,42 @@ public interface EmbeddingStore<Embedded> {
List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded);

/**
* Removes an embedding from the store.
* Removes a single embedding from the store by ID.
*
* @param id The unique identifier of the embedding to be removed.
* @param id The unique ID of the embedding to be removed.
*/
@Experimental
default void remove(String id) {
throw new UnsupportedOperationException("Not supported yet.");
}

/**
* Removes all embeddings from the store.
* Removes all embeddings that match the specified IDs from the store.
*
* @param ids A collection of unique IDs of the embeddings to be removed.
*/
default void removeAll() {
@Experimental
default void removeAll(Collection<String> ids) {
throw new UnsupportedOperationException("Not supported yet.");
}

/**
* Removes multiple embeddings from the store.
* Removes all embeddings that match the specified {@link Filter} from the store.
*
* @param ids A list of unique identifiers of the embeddings to be removed.
* @param filter The filter to be applied to the {@link Metadata} of the {@link TextSegment} during removal.
* Only embeddings whose {@code TextSegment}'s {@code Metadata}
* match the {@code Filter} will be removed.
*/
default void removeAll(List<String> ids) {
@Experimental
default void removeAll(Filter filter) {
throw new UnsupportedOperationException("Not supported yet.");
}

/**
* Removes multiple embeddings from the store.
*
* @param filter The filter to be applied to the {@link Metadata} during delete.
* Only {@link TextSegment}s whose {@link Metadata}
* matches the {@link Filter} will be deleted.
* Please note that not all {@link EmbeddingStore}s support this feature yet.
* This is an optional parameter. Default: no filtering, all will be deleted.
* Removes all embeddings from the store.
*/
default void removeAll(Filter filter) {
@Experimental
default void removeAll() {
throw new UnsupportedOperationException("Not supported yet.");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@

import javax.sql.DataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import java.util.*;
import java.util.stream.IntStream;

import static dev.langchain4j.internal.Utils.*;
Expand Down Expand Up @@ -62,17 +59,17 @@ public class PgVectorEmbeddingStore implements EmbeddingStore<TextSegment> {
* @param indexListSize The IVFFlat number of lists
* @param createTable Should create table automatically
* @param dropTableFirst Should drop table first, usually for testing
* @param metadataStorageConfig The {@link MetadataStorageConfig} config.
* @param metadataStorageConfig The {@link MetadataStorageConfig} config.
*/
@Builder(builderMethodName = "datasourceBuilder", builderClassName = "DatasourceBuilder")
protected PgVectorEmbeddingStore(DataSource datasource,
String table,
Integer dimension,
Boolean useIndex,
Integer indexListSize,
Boolean createTable,
Boolean dropTableFirst,
MetadataStorageConfig metadataStorageConfig) {
String table,
Integer dimension,
Boolean useIndex,
Integer indexListSize,
Boolean createTable,
Boolean dropTableFirst,
MetadataStorageConfig metadataStorageConfig) {
this.datasource = ensureNotNull(datasource, "datasource");
this.table = ensureNotBlank(table, "table");
MetadataStorageConfig config = getOrDefault(metadataStorageConfig, DefaultMetadataStorageConfig.defaultConfig());
Expand All @@ -99,7 +96,7 @@ protected PgVectorEmbeddingStore(DataSource datasource,
* @param indexListSize The IVFFlat number of lists
* @param createTable Should create table automatically
* @param dropTableFirst Should drop table first, usually for testing
* @param metadataStorageConfig The {@link MetadataStorageConfig} config.
* @param metadataStorageConfig The {@link MetadataStorageConfig} config.
*/
@SuppressWarnings("unused")
@Builder
Expand Down Expand Up @@ -129,8 +126,8 @@ private static DataSource createDataSource(String host, Integer port, String use
database = ensureNotBlank(database, "database");

PGSimpleDataSource source = new PGSimpleDataSource();
source.setServerNames(new String[] {host});
source.setPortNumbers(new int[] {port});
source.setServerNames(new String[]{host});
source.setPortNumbers(new int[]{port});
source.setDatabaseName(database);
source.setUser(user);
source.setPassword(password);
Expand All @@ -141,14 +138,15 @@ private static DataSource createDataSource(String host, Integer port, String use

/**
* Initialize metadata table following configuration
* @param dropTableFirst Should drop table first, usually for testing
* @param createTable Should create table automatically
* @param useIndex Should use <a href="https://github.com/pgvector/pgvector#ivfflat">IVFFlat</a> index
* @param dimension The vector dimension
* @param indexListSize The IVFFlat number of lists
*
* @param dropTableFirst Should drop table first, usually for testing
* @param createTable Should create table automatically
* @param useIndex Should use <a href="https://github.com/pgvector/pgvector#ivfflat">IVFFlat</a> index
* @param dimension The vector dimension
* @param indexListSize The IVFFlat number of lists
*/
protected void initTable(Boolean dropTableFirst, Boolean createTable, Boolean useIndex, Integer dimension,
Integer indexListSize) {
Integer indexListSize) {
String query = "init";
try (Connection connection = getConnection(); Statement statement = connection.createStatement()) {
if (dropTableFirst) {
Expand Down Expand Up @@ -192,7 +190,7 @@ public String add(Embedding embedding) {
/**
* Adds a given embedding to the store.
*
* @param id The unique identifier for the embedding to be added.
* @param id The unique identifier for the embedding to be added.
* @param embedding The embedding to be added to the store.
*/
@Override
Expand All @@ -203,7 +201,7 @@ public void add(String id, Embedding embedding) {
/**
* Adds a given embedding and the corresponding content that has been embedded to the store.
*
* @param embedding The embedding to be added to the store.
* @param embedding The embedding to be added to the store.
* @param textSegment Original content that was embedded.
* @return The auto-generated ID associated with the added embedding.
*/
Expand Down Expand Up @@ -231,7 +229,7 @@ public List<String> addAll(List<Embedding> embeddings) {
* Adds multiple embeddings and their corresponding contents that have been embedded to the store.
*
* @param embeddings A list of embeddings to be added to the store.
* @param embedded A list of original contents that were embedded.
* @param embedded A list of original contents that were embedded.
* @return A list of auto-generated IDs associated with the added embeddings.
*/
@Override
Expand All @@ -241,73 +239,51 @@ public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedde
return ids;
}

/**
* Removes an embedding from the store based on its unique identifier.
*
* @param id The unique identifier of the embedding to be removed.
*/
@Override
public void remove(String id) {
try (Connection connection = getConnection()) {
PreparedStatement statement = connection.prepareStatement(String.format(
"DELETE FROM %s WHERE embedding_id = ?", table));
ensureNotBlank(id, "id");
String sql = String.format("DELETE FROM %s WHERE embedding_id = ?", table);
try (Connection connection = getConnection();
PreparedStatement statement = connection.prepareStatement(sql)) {
statement.setObject(1, UUID.fromString(id));
statement.executeUpdate();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

/**
* Removes multiple embeddings from the store.
*
* @param ids A list of unique identifiers of the embeddings to be removed.
*/
@Override
public void removeAll(List<String> ids) {
if (ids != null && !ids.isEmpty()) {
try (Connection connection = getConnection()) {
PreparedStatement statement = connection.prepareStatement(String.format(
"DELETE FROM %s WHERE embedding_id = ANY (?)", table));
Array array = connection.createArrayOf("uuid", ids.stream().map(UUID::fromString).toArray());
statement.setArray(1, array);
statement.executeUpdate();
} catch (SQLException e) {
throw new RuntimeException(e);
}
public void removeAll(Collection<String> ids) {
ensureNotEmpty(ids, "ids");
String sql = String.format("DELETE FROM %s WHERE embedding_id = ANY (?)", table);
try (Connection connection = getConnection();
PreparedStatement statement = connection.prepareStatement(sql)) {
Array array = connection.createArrayOf("uuid", ids.stream().map(UUID::fromString).toArray());
statement.setArray(1, array);
statement.executeUpdate();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

/**
* Removes multiple embeddings from the store.
*
* @param filter The filter to be applied to the {@link Metadata} during delete.
* Only {@link TextSegment}s whose {@link Metadata}
* matches the {@link Filter} will be deleted.
* Please note that not all {@link EmbeddingStore}s support this feature yet.
* This is an optional parameter. Default: no filtering, all will be deleted.
*/
@Override
public void removeAll(Filter filter) {
try (Connection connection = getConnection()) {
String whereClause = (filter == null) ? "" : metadataHandler.whereClause(filter);
whereClause = (whereClause.isEmpty()) ? "" : "WHERE " + whereClause;
PreparedStatement statement = connection.prepareStatement(String.format(
"DELETE FROM %s %s", table, whereClause));
ensureNotNull(filter, "filter");
String whereClause = metadataHandler.whereClause(filter);
String sql = String.format("DELETE FROM %s WHERE %s", table, whereClause);
try (Connection connection = getConnection();
PreparedStatement statement = connection.prepareStatement(sql)) {
statement.executeUpdate();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

/**
* Removes all embeddings from the store.
*
*/
@Override
public void removeAll() {
try (Connection connection = getConnection()) {
connection.createStatement().executeUpdate(String.format("TRUNCATE TABLE %s", table));
try (Connection connection = getConnection();
Statement statement = connection.createStatement()) {
statement.executeUpdate(String.format("TRUNCATE TABLE %s", table));
} catch (SQLException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -339,25 +315,25 @@ public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request)
"WITH temp AS (SELECT (2 - (embedding <=> '%s')) / 2 AS score, embedding_id, embedding, text, " +
"%s FROM %s %s) SELECT * FROM temp WHERE score >= %s ORDER BY score desc LIMIT %s;",
referenceVector, join(",", metadataHandler.columnsNames()), table, whereClause, minScore, maxResults);
try (PreparedStatement selectStmt = connection.prepareStatement(query) ) {
try (ResultSet resultSet = selectStmt.executeQuery()) {
while (resultSet.next()) {
double score = resultSet.getDouble("score");
String embeddingId = resultSet.getString("embedding_id");
try (PreparedStatement selectStmt = connection.prepareStatement(query)) {
try (ResultSet resultSet = selectStmt.executeQuery()) {
while (resultSet.next()) {
double score = resultSet.getDouble("score");
String embeddingId = resultSet.getString("embedding_id");

PGvector vector = (PGvector) resultSet.getObject("embedding");
Embedding embedding = new Embedding(vector.toArray());
PGvector vector = (PGvector) resultSet.getObject("embedding");
Embedding embedding = new Embedding(vector.toArray());

String text = resultSet.getString("text");
TextSegment textSegment = null;
if (isNotNullOrBlank(text)) {
Metadata metadata = metadataHandler.fromResultSet(resultSet);
textSegment = TextSegment.from(text, metadata);
}
result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment));
}
}
}
String text = resultSet.getString("text");
TextSegment textSegment = null;
if (isNotNullOrBlank(text)) {
Metadata metadata = metadataHandler.fromResultSet(resultSet);
textSegment = TextSegment.from(text, metadata);
}
result.add(new EmbeddingMatch<>(score, embeddingId, embedding, textSegment));
}
}
}
} catch (SQLException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -423,6 +399,7 @@ table, join(",", metadataHandler.columnsNames()),
* Datasource connection
* Creates the vector extension and add the vector type if it does not exist.
* Could be overridden in case extension creation and adding type is done at datasource initialization step.
*
* @return Datasource connection
* @throws SQLException exception
*/
Expand Down
Loading

0 comments on commit 358f63c

Please sign in to comment.