Skip to content

Commit

Permalink
fix: make PR changes pt I
Browse files Browse the repository at this point in the history
  • Loading branch information
marieaurore123 committed Oct 3, 2024
1 parent f0840fb commit 27435e4
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 54 deletions.
2 changes: 1 addition & 1 deletion rig-lancedb/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ futures = "0.3.30"

[dev-dependencies]
tokio = "1.40.0"
anyhow = "1.0.89"
anyhow = "1.0.89"
5 changes: 3 additions & 2 deletions rig-lancedb/examples/vector_search_local_ann.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,17 @@ async fn main() -> Result<(), anyhow::Error> {
let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?;

// See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
vector_store
table
.create_index(
&["embedding"],
lancedb::index::Index::IvfPq(
IvfPqIndexBuilder::default()
// This overrides the default distance type of L2.
// Needs to be the same distance type as the one used in search params.
.distance_type(DistanceType::Cosine),
),
&["embedding"],
)
.execute()
.await?;

// Query the index
Expand Down
64 changes: 32 additions & 32 deletions rig-lancedb/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use lancedb::{
index::Index,
query::{QueryBase, VectorQuery},
DistanceType,
};
Expand All @@ -9,7 +8,7 @@ use rig::{
};
use serde::Deserialize;
use serde_json::Value;
use utils::Query;
use utils::QueryToJson;

mod utils;

Expand All @@ -24,7 +23,7 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
/// # Example
/// ```
/// use std::{env, sync::Arc};
///
/// use arrow_array::RecordBatchIterator;
/// use fixture::{as_record_batch, schema};
/// use rig::{
Expand All @@ -44,23 +43,23 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
/// // Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
/// let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
/// let openai_client = Client::new(&openai_api_key);
///
/// // Select the embedding model and generate our embeddings
/// let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);
///
/// let embeddings = EmbeddingsBuilder::new(model.clone())
/// .simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.")
/// .simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive")
/// .simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.")
/// .build()
/// .await?;
///
/// // Define the search params that will be used by the vector store to perform the vector search.
/// let search_params = SearchParams::default();
///
/// // Initialize LanceDB locally.
/// let db = lancedb::connect("data/lancedb-store").execute().await?;
///
/// // Create table with embeddings.
/// let record_batch = as_record_batch(embeddings, model.ndims());
/// let table = db
Expand All @@ -70,9 +69,9 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
/// )
/// .execute()
/// .await?;
///
/// let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?;
///
/// // Query the index
/// let results = vector_store
/// .top_n("My boss says I zindle too much, what does that mean?", 1)
Expand All @@ -86,7 +85,7 @@ fn serde_to_rig_error(e: serde_json::Error) -> VectorStoreError {
/// ))
/// })
/// .collect::<Result<Vec<_>, _>>()?;
///
/// println!("Results: {:?}", results);
/// ```
pub struct LanceDbVectorStore<M: EmbeddingModel> {
Expand Down Expand Up @@ -151,60 +150,71 @@ pub enum SearchType {
Approximate,
}

/// Parameters used to perform a vector search on a LanceDb table.
///
/// Every field is optional; `SearchParams::default()` leaves all of them as `None`.
#[derive(Debug, Clone, Default)]
pub struct SearchParams {
/// Distance metric used by the search.
/// Always set this to match the distance type that was used to train the index.
/// By default, set to L2.
distance_type: Option<DistanceType>,
/// Kind of vector search to run.
/// By default, ANN will be used if there is an index on the table.
/// By default, kNN will be used if there is NO index on the table.
/// To use these defaults, leave this set to `None`.
search_type: Option<SearchType>,
/// Set this value only when the search type is ANN.
/// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
nprobes: Option<usize>,
/// Set this value only when the search type is ANN.
/// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
refine_factor: Option<u32>,
/// If set to true, filtering will happen after the vector search instead of before.
/// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information.
post_filter: Option<bool>,
/// Name of the column to run the vector search on.
/// Only needed when more than one column contains lists of floats; when there is a single
/// such column it is chosen for the vector search automatically.
column: Option<String>,
}

impl SearchParams {
    /// Chooses the distance metric for the search.
    ///
    /// This should always match the distance type that was used to train the index.
    /// When it is never set, `DistanceType::L2` is the default.
    pub fn distance_type(self, distance_type: DistanceType) -> Self {
        Self {
            distance_type: Some(distance_type),
            ..self
        }
    }

    /// Chooses the kind of vector search to run.
    ///
    /// If this is never set, ANN is used when the table has an index and kNN is used when it
    /// does not. Leave the search type unset to keep those defaults.
    pub fn search_type(self, search_type: SearchType) -> Self {
        Self {
            search_type: Some(search_type),
            ..self
        }
    }

    /// Stores the `nprobes` value for the search.
    ///
    /// Set this value only when the search type is ANN.
    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
    pub fn nprobes(self, nprobes: usize) -> Self {
        Self {
            nprobes: Some(nprobes),
            ..self
        }
    }

    /// Stores the refine factor for the search.
    ///
    /// Set this value only when the search type is ANN.
    /// See [LanceDb ANN Search](https://lancedb.github.io/lancedb/ann_indexes/#querying-an-ann-index) for more information.
    pub fn refine_factor(self, refine_factor: u32) -> Self {
        Self {
            refine_factor: Some(refine_factor),
            ..self
        }
    }

    /// Stores the post-filter flag for the search.
    ///
    /// If set to true, filtering will happen after the vector search instead of before.
    /// See [LanceDb pre/post filtering](https://lancedb.github.io/lancedb/sql/#pre-and-post-filtering) for more information.
    pub fn post_filter(self, post_filter: bool) -> Self {
        Self {
            post_filter: Some(post_filter),
            ..self
        }
    }

    /// Names the column to run the vector search on.
    ///
    /// Only set this value if more than one column contains lists of floats.
    /// If there is only one such column, it is chosen for the vector search automatically.
    pub fn column(self, column: &str) -> Self {
        Self {
            column: Some(column.to_owned()),
            ..self
        }
    }
}

impl<M: EmbeddingModel> LanceDbVectorStore<M> {
/// Create an instance of `LanceDbVectorStore` with an existing table and model.
/// Define the id field name of the table.
/// Define search parameters that will be used to perform vector searches on the table.
pub async fn new(
table: lancedb::Table,
model: M,
Expand All @@ -218,16 +228,6 @@ impl<M: EmbeddingModel> LanceDbVectorStore<M> {
search_params,
})
}

/// Define an index on the specified fields of the lanceDB table for search optimization.
/// Note: it is required to add an index on the column containing the embeddings when performing an ANN type vector search.
pub async fn create_index(
&self,
index: Index,
field_names: &[impl AsRef<str>],
) -> Result<(), lancedb::Error> {
self.table.create_index(field_names, index).execute().await
}
}

impl<M: EmbeddingModel + std::marker::Sync + Send> VectorStoreIndex for LanceDbVectorStore<M> {
Expand Down
27 changes: 15 additions & 12 deletions rig-lancedb/src/utils/deserializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,21 +313,24 @@ impl RecordBatchDeserializer for RecordBatch {
.map(|item| serde_json::to_value(item).map_err(serde_to_rig_error))
.collect()
}
// Not yet fully supported
DataType::BinaryView
| DataType::Utf8View
| DataType::ListView(..)
| DataType::LargeListView(..) => {
todo!()
}
// Currently unstable
DataType::Float16 | DataType::Decimal256(..) => {
todo!()
}
_ => {
println!("Unsupported data type");
Ok(vec![serde_json::Value::Null])
}
| DataType::LargeListView(..) => Err(VectorStoreError::DatastoreError(Box::new(
ArrowError::CastError(format!(
"Data type: {} not yet fully supported",
column.data_type()
)),
))),
DataType::Float16 | DataType::Decimal256(..) => Err(
VectorStoreError::DatastoreError(Box::new(ArrowError::CastError(format!(
"Data type: {} currently unstable",
column.data_type()
)))),
),
_ => Err(VectorStoreError::DatastoreError(Box::new(
ArrowError::CastError(format!("Unsupported data type: {}", column.data_type())),
))),
}
}

Expand Down
4 changes: 2 additions & 2 deletions rig-lancedb/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ use crate::lancedb_to_rig_error;

/// Trait that facilitates the conversion of columnar data returned by a lanceDb query to serde_json::Value.
/// Used whenever a lanceDb table is queried.
pub trait Query {
pub trait QueryToJson {
async fn execute_query(&self) -> Result<Vec<serde_json::Value>, VectorStoreError>;
}

impl Query for lancedb::query::VectorQuery {
impl QueryToJson for lancedb::query::VectorQuery {
async fn execute_query(&self) -> Result<Vec<serde_json::Value>, VectorStoreError> {
let record_batches = self
.execute()
Expand Down
2 changes: 1 addition & 1 deletion rig-mongodb/examples/vector_search_mongodb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async fn main() -> Result<(), anyhow::Error> {

// Create a vector index on our vector store
// IMPORTANT: Reuse the same model that was used to generate the embeddings
let index = vector_store.index(model, "vector_index", SearchParams::new());
let index = vector_store.index(model, "vector_index", SearchParams::default());

// Query the index
let results = index
Expand Down
15 changes: 11 additions & 4 deletions rig-mongodb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,16 +158,13 @@ impl<M: EmbeddingModel> MongoDbVectorIndex<M> {
/// See [MongoDB Vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information
/// on each of the fields
pub struct SearchParams {
/// Pre-filter document for the vector search.
filter: mongodb::bson::Document,
/// Whether to use ANN or ENN search.
/// When set to true an ENN vector search is performed, otherwise an ANN search is performed.
exact: Option<bool>,
/// Number of nearest neighbors to use during the search.
/// Only set this field if `exact` is set to false.
num_candidates: Option<u32>,
}

impl SearchParams {
/// Initializes a new `SearchParams` with default values.
pub fn new() -> Self {
Self {
filter: doc! {},
Expand All @@ -176,16 +173,26 @@ impl SearchParams {
}
}

/// Sets the pre-filter field of the search params.
/// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
pub fn filter(mut self, filter: mongodb::bson::Document) -> Self {
self.filter = filter;
self
}

/// Sets the exact field of the search params.
/// If exact is true, an ENN vector search will be performed, otherwise, an ANN search will be performed.
/// By default, exact is false.
/// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
pub fn exact(mut self, exact: bool) -> Self {
self.exact = Some(exact);
self
}

/// Sets the num_candidates field of the search params.
/// Only set this field if exact is set to false.
/// Number of nearest neighbors to use during the search.
/// See [MongoDB vector Search](https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/) for more information.
pub fn num_candidates(mut self, num_candidates: u32) -> Self {
self.num_candidates = Some(num_candidates);
self
Expand Down

0 comments on commit 27435e4

Please sign in to comment.