Skip to content

Commit

Permalink
🚀 feat: Add Atlas MongoDB as an option for Vector Store (danny-avila#21)
Browse files Browse the repository at this point in the history
* Initial commit that starts the FastAPI without failure with MongoDB instead of pgvector

* Expanded AtlasMongoVector Class definition to be compatible with PGVector code; working pass the /query langchain portion

* Add a processing step to remove MongoDB ObjectID (_id) since it is not iterable by jsonable_encoder

* implement GET /ids API for MongoDB

* Fix the GET /documents query parameter

* Get /documents?ids implementation for MongoDB-Altas (initial commit: not complete)

* GET /documents?ids=xxx now returns metadata  properly

* custom_id->file_id: this is a bug even for pgvector

* Implement  DELETE /documents API

* reorganization environment variables

* restore all original code logic for pgvector

* Update README.md: Add Atlas MongoDB section

* Update README.md: further atlas mongo documentation

* Update README.md: add VECTOR_DB_TYPE doc
  • Loading branch information
jinzishuai authored May 11, 2024
1 parent c2dda20 commit ad107dc
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 47 deletions.
35 changes: 34 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ The following environment variables are required to run the application:
- Note: `OPENAI_API_KEY` will work but `RAG_OPENAI_API_KEY` will override it in order to not conflict with LibreChat setting.
- `RAG_OPENAI_BASEURL`: (Optional) The base URL for your OpenAI API Embeddings
- `RAG_OPENAI_PROXY`: (Optional) Proxy for OpenAI API Embeddings
- `POSTGRES_DB`: (Optional) The name of the PostgreSQL database.
- `VECTOR_DB_TYPE`: (Optional) select vector database type, default to `pgvector`.
- `POSTGRES_DB`: (Optional) The name of the PostgreSQL database, used when `VECTOR_DB_TYPE=pgvector`.
- `POSTGRES_USER`: (Optional) The username for connecting to the PostgreSQL database.
- `POSTGRES_PASSWORD`: (Optional) The password for connecting to the PostgreSQL database.
- `DB_HOST`: (Optional) The hostname or IP address of the PostgreSQL database server.
Expand Down Expand Up @@ -79,6 +80,38 @@ The following environment variables are required to run the application:

Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.

### Use Atlas MongoDB as Vector Database

Instead of using the default pgvector, we could use [Atlas MongoDB](https://www.mongodb.com/products/platform/atlas-vector-search) as the vector database. To do so, set the following environment variables

```env
VECTOR_DB_TYPE=atlas-mongo
ATLAS_MONGO_DB_URI=<mongodb+srv://...>
MONGO_VECTOR_COLLECTION=<collection name>
```

The `ATLAS_MONGO_DB_URI` could be the same or different from what is used by LibreChat. Even if it is the same, the `$MONGO_VECTOR_COLLECTION` collection needs to be a completely new one, separate from all collections used by LibreChat. In additional, create a vector search index for `$MONGO_VECTOR_COLLECTION` with the following json:

```json
{
"fields": [
{
"numDimensions": 1536,
"path": "embedding",
"similarity": "cosine",
"type": "vector"
},
{
"path": "file_id",
"type": "filter"
}
]
}
```

Follw one of the [four documented methods](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure) to create the vector index.


### Cloud Installation Settings:

#### AWS:
Expand Down
92 changes: 61 additions & 31 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from datetime import datetime

from dotenv import find_dotenv, load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings, \
HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_community.embeddings import (
HuggingFaceEmbeddings,
HuggingFaceHubEmbeddings,
OllamaEmbeddings,
)
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from starlette.middleware.base import BaseHTTPMiddleware

Expand All @@ -15,27 +18,37 @@
load_dotenv(find_dotenv())


def get_env_variable(var_name: str, default_value: str = None, required: bool = False) -> str:
def get_env_variable(
var_name: str, default_value: str = None, required: bool = False
) -> str:
value = os.getenv(var_name)
if value is None:
if default_value is None and required:
raise ValueError(f"Environment variable '{var_name}' not found.")
return default_value
return value

RAG_HOST = os.getenv('RAG_HOST', '0.0.0.0')
RAG_PORT = int(os.getenv('RAG_PORT', 8000))

RAG_HOST = os.getenv("RAG_HOST", "0.0.0.0")
RAG_PORT = int(os.getenv("RAG_PORT", 8000))

RAG_UPLOAD_DIR = get_env_variable("RAG_UPLOAD_DIR", "./uploads/")
if not os.path.exists(RAG_UPLOAD_DIR):
os.makedirs(RAG_UPLOAD_DIR, exist_ok=True)

VECTOR_DB_TYPE = get_env_variable("VECTOR_DB_TYPE", "pgvector")
POSTGRES_DB = get_env_variable("POSTGRES_DB", "mydatabase")
POSTGRES_USER = get_env_variable("POSTGRES_USER", "myuser")
POSTGRES_PASSWORD = get_env_variable("POSTGRES_PASSWORD", "mypassword")
DB_HOST = get_env_variable("DB_HOST", "db")
DB_PORT = get_env_variable("DB_PORT", "5432")
COLLECTION_NAME = get_env_variable("COLLECTION_NAME", "testcollection")
ATLAS_MONGO_DB_URI = get_env_variable(
"ATLAS_MONGO_DB_URI", "mongodb://127.0.0.1:27018/LibreChat"
)
MONGO_VECTOR_COLLECTION = get_env_variable(
"MONGO_VECTOR_COLLECTION", "vector_collection"
)

CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100"))
Expand All @@ -62,6 +75,7 @@ def get_env_variable(var_name: str, default_value: str = None, required: bool =
logger.setLevel(logging.INFO)

if console_json:

class JsonFormatter(logging.Formatter):
def __init__(self):
super(JsonFormatter, self).__init__()
Expand Down Expand Up @@ -96,7 +110,8 @@ def format(self, record):
formatter = JsonFormatter()
else:
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

handler = logging.StreamHandler() # or logging.FileHandler("app.log")
handler.setFormatter(formatter)
Expand All @@ -113,12 +128,11 @@ async def dispatch(self, request, call_next):
logger_method = logger.debug

logger_method(
f"Request {request.method} {request.url} - {response.status_code}",
extra={
HTTP_REQ: {"method": request.method,
"url": str(request.url)},
HTTP_RES: {"status_code": response.status_code},
},
f"Request {request.method} {request.url} - {response.status_code}",
extra={
HTTP_REQ: {"method": request.method, "url": str(request.url)},
HTTP_RES: {"status_code": response.status_code},
},
)

return response
Expand All @@ -135,33 +149,38 @@ async def dispatch(self, request, call_next):
RAG_OPENAI_PROXY = get_env_variable("RAG_OPENAI_PROXY", None)
AZURE_OPENAI_API_KEY = get_env_variable("AZURE_OPENAI_API_KEY", "")
RAG_AZURE_OPENAI_API_VERSION = get_env_variable("RAG_AZURE_OPENAI_API_VERSION", None)
RAG_AZURE_OPENAI_API_KEY = get_env_variable("RAG_AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY)
RAG_AZURE_OPENAI_API_KEY = get_env_variable(
"RAG_AZURE_OPENAI_API_KEY", AZURE_OPENAI_API_KEY
)
AZURE_OPENAI_ENDPOINT = get_env_variable("AZURE_OPENAI_ENDPOINT", "")
RAG_AZURE_OPENAI_ENDPOINT = get_env_variable("RAG_AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT).rstrip("/")
RAG_AZURE_OPENAI_ENDPOINT = get_env_variable(
"RAG_AZURE_OPENAI_ENDPOINT", AZURE_OPENAI_ENDPOINT
).rstrip("/")
HF_TOKEN = get_env_variable("HF_TOKEN", "")
OLLAMA_BASE_URL = get_env_variable("OLLAMA_BASE_URL", "http://ollama:11434")


## Embeddings


def init_embeddings(provider, model):
if provider == "openai":
return OpenAIEmbeddings(
model=model,
api_key=RAG_OPENAI_API_KEY,
openai_api_base=RAG_OPENAI_BASEURL,
openai_proxy=RAG_OPENAI_PROXY
openai_proxy=RAG_OPENAI_PROXY,
)
elif provider == "azure":
return AzureOpenAIEmbeddings(
azure_deployment=model,
api_key=RAG_AZURE_OPENAI_API_KEY,
azure_endpoint=RAG_AZURE_OPENAI_ENDPOINT,
api_version=RAG_AZURE_OPENAI_API_VERSION
api_version=RAG_AZURE_OPENAI_API_VERSION,
)
elif provider == "huggingface":
return HuggingFaceEmbeddings(model_name=model, encode_kwargs={
'normalize_embeddings': True})
return HuggingFaceEmbeddings(
model_name=model, encode_kwargs={"normalize_embeddings": True}
)
elif provider == "huggingfacetei":
return HuggingFaceHubEmbeddings(model=model)
elif provider == "ollama":
Expand All @@ -173,20 +192,20 @@ def init_embeddings(provider, model):
EMBEDDINGS_PROVIDER = get_env_variable("EMBEDDINGS_PROVIDER", "openai").lower()

if EMBEDDINGS_PROVIDER == "openai":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL",
"text-embedding-3-small")
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small")

elif EMBEDDINGS_PROVIDER == "azure":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL",
"text-embedding-3-small")
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small")

elif EMBEDDINGS_PROVIDER == "huggingface":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL",
"sentence-transformers/all-MiniLM-L6-v2")
EMBEDDINGS_MODEL = get_env_variable(
"EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)

elif EMBEDDINGS_PROVIDER == "huggingfacetei":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL",
"http://huggingfacetei:3000")
EMBEDDINGS_MODEL = get_env_variable(
"EMBEDDINGS_MODEL", "http://huggingfacetei:3000"
)

elif EMBEDDINGS_PROVIDER == "ollama":
EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text")
Expand All @@ -197,14 +216,25 @@ def init_embeddings(provider, model):

logger.info(f"Initialized embeddings of type: {type(embeddings)}")

## Vector store

vector_store = get_vector_store(
# Vector store
if VECTOR_DB_TYPE == "pgvector":
vector_store = get_vector_store(
connection_string=CONNECTION_STRING,
embeddings=embeddings,
collection_name=COLLECTION_NAME,
mode="async",
)
)
elif VECTOR_DB_TYPE == "atlas-mongo":
# atlas-mongo vector:
vector_store = get_vector_store(
connection_string=ATLAS_MONGO_DB_URI,
embeddings=embeddings,
collection_name=MONGO_VECTOR_COLLECTION,
mode="atlas-mongo",
)
else:
raise ValueError(f"Unsupported vector store type: {VECTOR_DB_TYPE}")

retriever = vector_store.as_retriever()

known_source_ext = [
Expand Down
17 changes: 11 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,16 @@
# RAG_EMBEDDING_MODEL,
# RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# RAG_TEMPLATE,
VECTOR_DB_TYPE,
)


@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup logic goes here
await PSQLDatabase.get_pool() # Initialize the pool
await ensure_custom_id_index_on_embedding()
if VECTOR_DB_TYPE == "pgvector":
await PSQLDatabase.get_pool() # Initialize the pool
await ensure_custom_id_index_on_embedding()

yield

Expand Down Expand Up @@ -105,7 +107,10 @@ async def get_all_ids():


def isHealthOK():
return pg_health_check()
if VECTOR_DB_TYPE == "pgvector":
return pg_health_check()
else:
return True


@app.get("/health")
Expand Down Expand Up @@ -137,7 +142,7 @@ async def get_documents_by_ids(ids: list[str] = Query(...)):


@app.delete("/documents")
async def delete_documents(ids: list[str]):
async def delete_documents(ids: list[str] = Query(...)):
try:
if isinstance(vector_store, AsyncPgVector):
existing_ids = await vector_store.get_all_ids()
Expand Down Expand Up @@ -499,11 +504,11 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody):
vector_store.similarity_search_with_score_by_vector,
embedding,
k=body.k,
filter={"custom_id": {"$in": body.file_ids}},
filter={"file_id": {"$in": body.file_ids}},
)
else:
documents = vector_store.similarity_search_with_score_by_vector(
embedding, k=body.k, filter={"custom_id": {"$in": body.file_ids}}
embedding, k=body.k, filter={"file_id": {"$in": body.file_ids}}
)

return documents
Expand Down
2 changes: 2 additions & 0 deletions requirements.lite.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ python-multipart==0.0.9
aiofiles==23.2.1
rapidocr-onnxruntime==1.3.17
opencv-python-headless==4.9.0.80
pymongo==4.6.3
langchain-mongodb==0.1.3
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ sentence_transformers==2.5.1
aiofiles==23.2.1
rapidocr-onnxruntime==1.3.17
opencv-python-headless==4.9.0.80
pymongo==4.6.3
langchain-mongodb==0.1.3
Loading

0 comments on commit ad107dc

Please sign in to comment.