Commit

feat: update examples to use new version of VectorStoreIndex trait
marieaurore123 committed Oct 2, 2024
1 parent 921b313 commit 0050925
Showing 9 changed files with 272 additions and 666 deletions.
67 changes: 67 additions & 0 deletions rig-lancedb/examples/fixtures/lib.rs
@@ -0,0 +1,67 @@
use std::sync::Arc;

use arrow_array::{types::Float64Type, ArrayRef, FixedSizeListArray, RecordBatch, StringArray};
use lancedb::arrow::arrow_schema::{DataType, Field, Fields, Schema};
use rig::embeddings::DocumentEmbeddings;

// Schema of table in LanceDB.
pub fn schema(dims: usize) -> Schema {
    Schema::new(Fields::from(vec![
        Field::new("id", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new(
            "embedding",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float64, true)),
                dims as i32,
            ),
            false,
        ),
    ]))
}

// Convert DocumentEmbeddings objects to a RecordBatch.
pub fn as_record_batch(
    records: Vec<DocumentEmbeddings>,
    dims: usize,
) -> Result<RecordBatch, lancedb::arrow::arrow_schema::ArrowError> {
    let id = StringArray::from_iter_values(
        records
            .iter()
            .flat_map(|record| (0..record.embeddings.len()).map(|i| format!("{}-{i}", record.id)))
            .collect::<Vec<_>>(),
    );

    let content = StringArray::from_iter_values(
        records
            .iter()
            .flat_map(|record| {
                record
                    .embeddings
                    .iter()
                    .map(|embedding| embedding.document.clone())
            })
            .collect::<Vec<_>>(),
    );

    let embedding = FixedSizeListArray::from_iter_primitive::<Float64Type, _, _>(
        records
            .into_iter()
            .flat_map(|record| {
                record
                    .embeddings
                    .into_iter()
                    .map(|embedding| embedding.vec.into_iter().map(Some).collect::<Vec<_>>())
                    .map(Some)
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>(),
        dims as i32,
    );

    RecordBatch::try_from_iter(vec![
        ("id", Arc::new(id) as ArrayRef),
        ("content", Arc::new(content) as ArrayRef),
        ("embedding", Arc::new(embedding) as ArrayRef),
    ])
}
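For orientation, here is a minimal sketch of how these two helpers are wired together, using only calls that appear in the updated examples below. The function name is hypothetical, and the connection/table types are assumptions based on how the examples use them:

```rust
use std::sync::Arc;

use arrow_array::RecordBatchIterator;
use rig::embeddings::{DocumentEmbeddings, EmbeddingModel};

// Hypothetical wrapper: build the "definitions" table from document
// embeddings, pairing `as_record_batch` with the matching `schema`.
async fn create_definitions_table(
    db: &lancedb::Connection,
    model: &impl EmbeddingModel,
    embeddings: Vec<DocumentEmbeddings>,
) -> anyhow::Result<lancedb::Table> {
    // `as_record_batch` yields a Result, which RecordBatchIterator accepts as-is.
    let record_batch = as_record_batch(embeddings, model.ndims());
    let table = db
        .create_table(
            "definitions",
            RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))),
        )
        .execute()
        .await?;
    Ok(table)
}
```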
71 changes: 51 additions & 20 deletions rig-lancedb/examples/vector_search_local_ann.rs
@@ -1,29 +1,35 @@
use std::env;
use std::{env, sync::Arc};

use arrow_array::RecordBatchIterator;
use fixture::{as_record_batch, schema};
use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType};
use rig::{
completion::Prompt,
embeddings::EmbeddingsBuilder,
embeddings::{EmbeddingModel, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{VectorStore, VectorStoreIndexDyn},
vector_store::VectorStoreIndexDyn,
};
use rig_lancedb::{LanceDbVectorStore, SearchParams};
use serde::Deserialize;

#[path = "./fixtures/lib.rs"]
mod fixture;

#[derive(Deserialize, Debug)]
pub struct VectorSearchResult {
pub id: String,
pub content: String,
}

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
let openai_api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
let openai_client = Client::new(&openai_api_key);

// Select the embedding model and generate our embeddings
// Select an embedding model.
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let search_params = SearchParams::default().distance_type(DistanceType::Cosine);

// Initialize LanceDB locally.
let db = lancedb::connect("data/lancedb-store").execute().await?;
let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?;

// Generate test data for RAG demo
let agent = openai_client
.agent("gpt-4o")
@@ -39,6 +45,7 @@ async fn main() -> Result<(), anyhow::Error> {
definitions.extend(definitions.clone());
definitions.extend(definitions.clone());

// Generate embeddings for the test data.
let embeddings = EmbeddingsBuilder::new(model.clone())
.simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.")
.simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive")
@@ -47,26 +54,50 @@ async fn main() -> Result<(), anyhow::Error> {
.build()
.await?;

// Add embeddings to vector store
// vector_store.add_documents(embeddings).await?;
// Define the search params that will be used by the vector store to perform the vector search.
let search_params = SearchParams::default().distance_type(DistanceType::Cosine);

// Initialize LanceDB locally.
let db = lancedb::connect("data/lancedb-store").execute().await?;

// Create table with embeddings.
let record_batch = as_record_batch(embeddings, model.ndims());
let table = db
.create_table(
"definitions",
RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))),
)
.execute()
.await?;

let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?;

// See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
vector_store
.create_index(lancedb::index::Index::IvfPq(
IvfPqIndexBuilder::default()
// This overrides the default distance type of L2.
// Needs to be the same distance type as the one used in search params.
.distance_type(DistanceType::Cosine),
))
.create_index(
lancedb::index::Index::IvfPq(
IvfPqIndexBuilder::default()
// This overrides the default distance type of L2.
// Needs to be the same distance type as the one used in search params.
.distance_type(DistanceType::Cosine),
),
&["embedding"],
)
.await?;

// Query the index
let results = vector_store
.top_n("My boss says I zindle too much, what does that mean?", 1)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc))
.collect::<Vec<_>>();
.map(|(score, id, doc)| {
anyhow::Ok((
score,
id,
serde_json::from_value::<VectorSearchResult>(doc)?,
))
})
.collect::<Result<Vec<_>, _>>()?;

println!("Results: {:?}", results);

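The core of this change: `top_n` now returns `(score, id, serde_json::Value)` triples and leaves deserialization to the caller. A minimal, self-contained sketch of just that pattern, with an illustrative JSON payload standing in for a real search hit:

```rust
use serde::Deserialize;

#[derive(Deserialize, Debug)]
pub struct VectorSearchResult {
    pub id: String,
    pub content: String,
}

fn main() -> anyhow::Result<()> {
    // Illustrative stand-in for the `doc` part of one (score, id, doc)
    // triple returned by `top_n`.
    let doc = serde_json::json!({
        "id": "doc1-0",
        "content": "Definition of *zindle (verb)*: to pretend to be working ..."
    });
    let result: VectorSearchResult = serde_json::from_value(doc)?;
    println!("{result:?}");
    Ok(())
}
```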
42 changes: 28 additions & 14 deletions rig-lancedb/examples/vector_search_local_enn.rs
@@ -1,11 +1,17 @@
use std::env;
use std::{env, sync::Arc};

use arrow_array::RecordBatchIterator;
use fixture::{as_record_batch, schema};
use rig::{
embeddings::EmbeddingsBuilder,
embeddings::{EmbeddingModel, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{VectorStore, VectorStoreIndexDyn},
vector_store::VectorStoreIndexDyn,
};
use rig_lancedb::{LanceDbVectorStore, SearchParams};
use serde::Deserialize;

#[path = "./fixtures/lib.rs"]
mod fixture;

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
@@ -16,27 +22,35 @@ async fn main() -> Result<(), anyhow::Error> {
// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

// Initialize LanceDB locally.
let db = lancedb::connect("data/lancedb-store").execute().await?;
let mut vector_store = LanceDbVectorStore::new(&db, &model, &SearchParams::default()).await?;

let embeddings = EmbeddingsBuilder::new(model.clone())
.simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.")
.simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive")
.simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.")
.build()
.await?;

// Add embeddings to vector store
// vector_store.add_documents(embeddings).await?;
// Define the search params that will be used by the vector store to perform the vector search.
let search_params = SearchParams::default();

// Initialize LanceDB locally.
let db = lancedb::connect("data/lancedb-store").execute().await?;

// Create table with embeddings.
let record_batch = as_record_batch(embeddings, model.ndims());
let table = db
.create_table(
"definitions",
RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))),
)
.execute()
.await?;

let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?;

// Query the index
let results = vector_store
.top_n("My boss says I zindle too much, what does that mean?", 1)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc))
.collect::<Vec<_>>();
.top_n_ids("My boss says I zindle too much, what does that mean?", 1)
.await?;

println!("Results: {:?}", results);

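Unlike the ANN example, this one creates no IVF-PQ index, so the search is exhaustive (exact nearest neighbour), and it queries with `top_n_ids`, which skips document payloads. A hedged sketch contrasting the two query styles; the helper name is hypothetical and the signatures are inferred from the calls in these examples:

```rust
use rig::vector_store::VectorStoreIndexDyn;

// Hypothetical helper contrasting the two query styles on any store
// implementing VectorStoreIndexDyn, such as LanceDbVectorStore.
async fn query_both(store: &impl VectorStoreIndexDyn) -> anyhow::Result<()> {
    // ENN style (this example): scores and ids only.
    let ids = store
        .top_n_ids("My boss says I zindle too much, what does that mean?", 1)
        .await?;
    println!("ids: {ids:?}");

    // ANN style (previous example): full documents as serde_json::Value,
    // left to the caller to deserialize.
    let docs = store
        .top_n("My boss says I zindle too much, what does that mean?", 1)
        .await?;
    println!("docs: {docs:?}");
    Ok(())
}
```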
83 changes: 56 additions & 27 deletions rig-lancedb/examples/vector_search_s3_ann.rs
@@ -1,17 +1,28 @@
use std::env;
use std::{env, sync::Arc};

use arrow_array::RecordBatchIterator;
use fixture::{as_record_batch, schema};
use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType};
use rig::{
completion::Prompt,
embeddings::EmbeddingsBuilder,
embeddings::{EmbeddingModel, EmbeddingsBuilder},
providers::openai::{Client, TEXT_EMBEDDING_ADA_002},
vector_store::{VectorStore, VectorStoreIndexDyn},
vector_store::VectorStoreIndexDyn,
};
use rig_lancedb::{LanceDbVectorStore, SearchParams};
use serde::Deserialize;

#[path = "./fixtures/lib.rs"]
mod fixture;

#[derive(Deserialize, Debug)]
pub struct VectorSearchResult {
pub id: String,
pub content: String,
}

// Note: see docs to deploy LanceDB on other cloud providers such as Google and Azure.
// https://lancedb.github.io/lancedb/guides/storage/

#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
// Initialize OpenAI client. Use this to generate embeddings (and generate test data for RAG demo).
@@ -21,23 +32,13 @@ async fn main() -> Result<(), anyhow::Error> {
// Select the embedding model and generate our embeddings
let model = openai_client.embedding_model(TEXT_EMBEDDING_ADA_002);

let search_params = SearchParams::default().distance_type(DistanceType::Cosine);

// Initialize LanceDB on S3.
// Note: see below docs for more options and IAM permission required to read/write to S3.
// https://lancedb.github.io/lancedb/guides/storage/#aws-s3
let db = lancedb::connect("s3://lancedb-test-829666124233")
.execute()
.await?;
let mut vector_store = LanceDbVectorStore::new(&db, &model, &search_params).await?;

// Generate test data for RAG demo
let agent = openai_client
.agent("gpt-4o")
.preamble("Return the answer as JSON containing a list of strings in the form: `Definition of {generated_word}: {generated definition}`. Return ONLY the JSON string generated, nothing else.")
.build();
let response = agent
.prompt("Invent at least 100 words and their definitions")
.prompt("Invent 100 words and their definitions")
.await?;
let mut definitions: Vec<String> = serde_json::from_str(&response)?;

@@ -46,34 +47,62 @@ async fn main() -> Result<(), anyhow::Error> {
definitions.extend(definitions.clone());
definitions.extend(definitions.clone());

let embeddings: Vec<rig::embeddings::DocumentEmbeddings> = EmbeddingsBuilder::new(model.clone())
// Generate embeddings for the test data.
let embeddings = EmbeddingsBuilder::new(model.clone())
.simple_document("doc0", "Definition of *flumbrel (noun)*: a small, seemingly insignificant item that you constantly lose or misplace, such as a pen, hair tie, or remote control.")
.simple_document("doc1", "Definition of *zindle (verb)*: to pretend to be working on something important while actually doing something completely unrelated or unproductive")
.simple_document("doc2", "Definition of *glimber (adjective)*: describing a state of excitement mixed with nervousness, often experienced before an important event or decision.")
.simple_documents(definitions.clone().into_iter().enumerate().map(|(i, def)| (format!("doc{}", i+3), def)).collect())
.build()
.await?;

// Add embeddings to vector store
// vector_store.add_documents(embeddings).await?;
// Define the search params that will be used by the vector store to perform the vector search.
let search_params = SearchParams::default().distance_type(DistanceType::Cosine);

// Initialize LanceDB on S3.
// Note: see below docs for more options and IAM permission required to read/write to S3.
// https://lancedb.github.io/lancedb/guides/storage/#aws-s3
let db = lancedb::connect("s3://lancedb-test-829666124233")
.execute()
.await?;
// Create table with embeddings.
let record_batch = as_record_batch(embeddings, model.ndims());
let table = db
.create_table(
"definitions",
RecordBatchIterator::new(vec![record_batch], Arc::new(schema(model.ndims()))),
)
.execute()
.await?;

let vector_store = LanceDbVectorStore::new(table, model, "id", search_params).await?;

// See [LanceDB indexing](https://lancedb.github.io/lancedb/concepts/index_ivfpq/#product-quantization) for more information
vector_store
.create_index(lancedb::index::Index::IvfPq(
IvfPqIndexBuilder::default()
// This overrides the default distance type of L2.
// Needs to be the same distance type as the one used in search params.
.distance_type(DistanceType::Cosine),
))
.create_index(
lancedb::index::Index::IvfPq(
IvfPqIndexBuilder::default()
// This overrides the default distance type of L2.
// Needs to be the same distance type as the one used in search params.
.distance_type(DistanceType::Cosine),
),
&["embedding"],
)
.await?;

// Query the index
let results = vector_store
.top_n("My boss says I zindle too much, what does that mean?", 1)
.top_n("I'm always looking for my phone, I always seem to forget it in the most counterintuitive places. What's the word for this feeling?", 1)
.await?
.into_iter()
.map(|(score, id, doc)| (score, id, doc))
.collect::<Vec<_>>();
.map(|(score, id, doc)| {
anyhow::Ok((
score,
id,
serde_json::from_value::<VectorSearchResult>(doc)?,
))
})
.collect::<Result<Vec<_>, _>>()?;

println!("Results: {:?}", results);

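One constraint both ANN examples restate in comments: the IVF-PQ index and the SearchParams must agree on distance type (Cosine here; the index default is L2). A hedged sketch isolating that step; the wrapper is hypothetical and assumes LanceDbVectorStore is generic over the embedding model, as its constructor suggests:

```rust
use lancedb::{index::vector::IvfPqIndexBuilder, DistanceType};
use rig::embeddings::EmbeddingModel;
use rig_lancedb::LanceDbVectorStore;

// Hypothetical wrapper around the indexing step used in both ANN examples.
async fn build_ann_index<M: EmbeddingModel>(
    vector_store: &LanceDbVectorStore<M>,
) -> anyhow::Result<()> {
    vector_store
        .create_index(
            lancedb::index::Index::IvfPq(
                IvfPqIndexBuilder::default()
                    // Must match the distance type set on SearchParams;
                    // the builder's default would otherwise be L2.
                    .distance_type(DistanceType::Cosine),
            ),
            &["embedding"], // the embedding column defined in fixtures/lib.rs
        )
        .await?;
    Ok(())
}
```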