Skip to content

Commit

Permalink
fix: Use Qdrant wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
iusztinpaul committed May 25, 2024
1 parent b7064ea commit bb375a5
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 32 deletions.
10 changes: 9 additions & 1 deletion course/module-5/.env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
OPENAI_API_KEY = "str"
HUGGINGFACE_ACCESS_TOKEN = "str"

COMET_API_KEY = "str"
COMET_WORKSPACE = "str"
COMET_PROJECT = "scrabble"
COMET_PROJECT = "llm-twin-course"

QWAK_DEPLOYMENT_MODEL_ID = "str"
QWAK_DEPLOYMENT_MODEL_API = "str"

QDRANT_CLOUD_URL = "str"
QDRANT_APIKEY = "str"
26 changes: 16 additions & 10 deletions course/module-5/db/qdrant.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from qdrant_client import QdrantClient
import logger_utils
from qdrant_client import QdrantClient, models
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import Batch, Distance, VectorParams

import logger_utils

from settings import settings

logger = logger_utils.get_logger(__name__)


class QdrantDatabaseConnector:
_instance: QdrantClient = None

def __init__(self):
_instance: QdrantClient | None = None
def __init__(self) -> None:
if self._instance is None:
try:
if settings.USE_QDRANT_CLOUD:
Expand All @@ -24,12 +25,12 @@ def __init__(self):
host=settings.QDRANT_DATABASE_HOST,
port=settings.QDRANT_DATABASE_PORT,
)

except UnexpectedResponse:
logger.exception(
"Couldn't connect to the database.",
"Couldn't connect to Qdrant.",
host=settings.QDRANT_DATABASE_HOST,
port=settings.QDRANT_DATABASE_PORT,
url=settings.QDRANT_CLOUD_URL,
)

raise
Expand Down Expand Up @@ -57,6 +58,14 @@ def write_data(self, collection_name: str, points: Batch):
logger.exception("An error occurred while inserting data.")

raise

def search(self, collection_name: str, query_vector: list, query_filter: models.Filter, limit: int) -> list:
return self._instance.search(
collection_name=collection_name,
query_vector=query_vector,
query_filter=query_filter,
limit=limit,
)

def scroll(self, collection_name: str, limit: int):
return self._instance.scroll(collection_name=collection_name, limit=limit)
Expand All @@ -66,6 +75,3 @@ def close(self):
self._instance.close()

logger.info("Connected to database has been closed.")


connection = QdrantDatabaseConnector()
5 changes: 4 additions & 1 deletion course/module-5/finetuning_model/dataset_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def split_data(self, artifact_name: str) -> tuple:
return training_file_path, validation_file_path
except Exception as e:
logging.error(f"Error splitting data: {str(e)}")

raise

def download_dataset(self, file_name: str):
def download_dataset(self, file_name: str) -> tuple:
self.get_artifact(file_name)

return self.split_data(file_name)
3 changes: 2 additions & 1 deletion course/module-5/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
class ModelInference:
def __init__(self) -> None:
self.qwak_client = RealTimeClient(
model_id=settings.MODEL_ID, model_api=settings.MODEL_API
model_id=settings.QWAK_DEPLOYMENT_MODEL_ID,
model_api=settings.QWAK_DEPLOYMENT_MODEL_API,
)
self.template = InferenceTemplate()
self.prompt_monitoring_manager = PromptMonitoringManager()
Expand Down
15 changes: 7 additions & 8 deletions course/module-5/rag/retriever.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import concurrent.futures

from qdrant_client import QdrantClient, models
from sentence_transformers.SentenceTransformer import SentenceTransformer

import logger_utils
import utils
from db.qdrant import QdrantDatabaseConnector
from qdrant_client import models
from rag.query_expanison import QueryExpansion
from rag.reranking import Reranker
from rag.self_query import SelfQuery
from sentence_transformers.SentenceTransformer import SentenceTransformer
from settings import settings

logger = logger_utils.get_logger(__name__)
Expand All @@ -18,11 +18,8 @@ class VectorRetriever:
Class for retrieving vectors from a Vector store in a RAG system using query expansion and Multitenancy search.
"""

def __init__(self, query: str):
self._client = QdrantClient(
host=settings.QDRANT_DATABASE_HOST,
port=settings.QDRANT_DATABASE_PORT,
)
def __init__(self, query: str) -> None:
self._client = QdrantDatabaseConnector()
self.query = query
self._embedder = SentenceTransformer(settings.EMBEDDING_MODEL_ID)
self._query_expander = QueryExpansion()
Expand All @@ -33,7 +30,9 @@ def _search_single_query(
self, generated_query: str, metadata_filter_value: str, k: int
):
assert k > 3, "k should be greater than 3"

query_vector = self._embedder.encode(generated_query).tolist()

vectors = [
self._client.search(
collection_name="vector_posts",
Expand Down
18 changes: 7 additions & 11 deletions course/module-5/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,16 @@
class AppSettings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

# Embeddings config
# Embeddings config
EMBEDDING_MODEL_ID: str = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_MODEL_MAX_INPUT_LENGTH: int = 256
EMBEDDING_SIZE: int = 384
EMBEDDING_MODEL_DEVICE: str = "cpu"

# OpenAI config
OPENAI_MODEL_ID: str = "gpt-4-1106-preview"
OPENAI_API_KEY: str | None = None

# MongoDB config
MONGO_DATABASE_HOST: str = "mongodb://localhost:30001,localhost:30002,localhost:30003/?replicaSet=my-replica-set"
MONGO_DATABASE_NAME: str = "scrabble"

# QdrantDB config
QDRANT_DATABASE_HOST: str = "localhost"
QDRANT_DATABASE_PORT: int = 6333
Expand All @@ -34,16 +30,16 @@ class AppSettings(BaseSettings):
RABBITMQ_PORT: int = 5673

# CometML config
COMET_API_KEY: str | None = None
COMET_WORKSPACE: str = "vladadu"
COMET_PROJECT: str = "llm-twin"
COMET_API_KEY: str
COMET_WORKSPACE: str
COMET_PROJECT: str = "llm-twin-course"

# LLM Model config
TOKENIZERS_PARALLELISM: str = "false"
HUGGINGFACE_ACCESS_TOKEN: str | None = None
MODEL_TYPE: str = "mistralai/Mistral-7B-Instruct-v0.1"
MODEL_ID: str = "copywriter_model"
MODEL_API: str = (
QWAK_DEPLOYMENT_MODEL_ID: str = "copywriter_model"
QWAK_DEPLOYMENT_MODEL_API: str = (
"https://models.llm-twin.qwak.ai/v1/copywriter_model/default/predict"
)

Expand Down

0 comments on commit bb375a5

Please sign in to comment.