forked from decodingml/llm-twin-course
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bfd4a1e
commit e0044f2
Showing
17 changed files
with
1,328 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ authors = [ | |
"Paul Iusztin <[email protected]>", | ||
"Alex Vesa <[email protected]>" | ||
] | ||
package-mode = false | ||
#package-mode = false | ||
readme = "README.md" | ||
|
||
[tool.ruff] | ||
|
@@ -39,6 +39,9 @@ pip = "^24.0" | |
install = "^1.3.5" | ||
comet-ml = "^3.41.0" | ||
ruff = "^0.4.3" | ||
comet-llm = "^2.2.4" | ||
qwak-sdk = "^0.5.69" | ||
pandas = "^2.2.2" | ||
|
||
|
||
[build-system] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import uuid | ||
from typing import List, Optional | ||
|
||
import logger_utils | ||
from db.errors import ImproperlyConfigured | ||
from db.mongo import connection | ||
from pydantic import UUID4, BaseModel, ConfigDict, Field | ||
from pymongo import errors | ||
|
||
# Database handle the document classes in this module index to reach their collections.
_database = connection.get_database("scrabble")

# Module-level logger, bound with this module's name.
logger = logger_utils.get_logger(__name__)
|
||
|
||
class BaseDocument(BaseModel):
    """Base pydantic model for documents stored in a Mongo collection.

    Subclasses must declare an inner ``Settings`` class whose ``name``
    attribute is the collection name; ``_get_collection_name`` raises
    ``ImproperlyConfigured`` otherwise.
    """

    # Generated client-side; ``to_mongo`` writes it out as the string "_id".
    id: UUID4 = Field(default_factory=uuid.uuid4)

    model_config = ConfigDict(from_attributes=True, populate_by_name=True)

    @classmethod
    def from_mongo(cls, data: dict):
        """Build a model from a raw Mongo document, mapping "_id" to "id".

        Args:
            data: Raw document as returned by pymongo. It is left unmodified.

        Returns:
            A validated instance of ``cls``, or ``data`` itself when it is
            falsy (``None`` or an empty dict).
        """
        if not data:
            return data

        # Copy instead of popping so the caller's dict is not mutated.
        fields = {key: value for key, value in data.items() if key != "_id"}
        # A missing "_id" is still passed explicitly as None (no fallback to
        # the default factory), matching the previous behaviour.
        fields["id"] = data.get("_id")
        return cls(**fields)

    def to_mongo(self, **kwargs) -> dict:
        """Serialize the model to a Mongo-ready dict, mapping "id" to "_id".

        Args:
            **kwargs: Forwarded to ``model_dump``. ``exclude_unset`` defaults
                to ``False`` and ``by_alias`` to ``True`` when not supplied.

        Returns:
            The dumped fields, with "id" replaced by its string form under
            the "_id" key.
        """
        exclude_unset = kwargs.pop("exclude_unset", False)
        by_alias = kwargs.pop("by_alias", True)

        parsed = self.model_dump(
            exclude_unset=exclude_unset, by_alias=by_alias, **kwargs
        )

        if "_id" not in parsed and "id" in parsed:
            parsed["_id"] = str(parsed.pop("id"))

        return parsed

    def save(self, **kwargs):
        """Insert this document into its collection.

        Args:
            **kwargs: Forwarded to ``to_mongo``.

        Returns:
            The inserted id, or ``None`` when the write fails with a
            ``WriteError`` (the failure is logged).
        """
        collection = _database[self._get_collection_name()]

        try:
            result = collection.insert_one(self.to_mongo(**kwargs))
        except errors.WriteError:
            logger.exception("Failed to insert document.")
            return None

        return result.inserted_id

    @classmethod
    def get_or_create(cls, **filter_options) -> Optional[str]:
        """Return the id of the document matching the filter, creating it if absent.

        Args:
            **filter_options: Used both as the ``find_one`` filter and, when
                nothing matches, as the field values of the new document.

        Returns:
            The id of the existing or newly saved document, or ``None`` when
            the operation fails (the failure is logged).
        """
        collection = _database[cls._get_collection_name()]
        try:
            existing = collection.find_one(filter_options)
            if existing:
                return str(cls.from_mongo(existing).id)

            return cls(**filter_options).save()
        except errors.OperationFailure:
            logger.exception("Failed to retrieve or create document.")
            return None

    @classmethod
    def bulk_insert(cls, documents: List, **kwargs) -> Optional[List[str]]:
        """Insert several documents into this class's collection at once.

        Args:
            documents: Model instances to insert.
            **kwargs: Forwarded to each document's ``to_mongo``.

        Returns:
            The inserted ids, an empty list when ``documents`` is empty, or
            ``None`` when the write fails (the failure is logged).
        """
        collection = _database[cls._get_collection_name()]

        # insert_many rejects an empty list with a TypeError; nothing to do.
        if not documents:
            return []

        try:
            result = collection.insert_many(
                [doc.to_mongo(**kwargs) for doc in documents]
            )
        # insert_many reports failed writes as BulkWriteError, which is not a
        # subclass of WriteError, so it has to be caught explicitly.
        except (errors.WriteError, errors.BulkWriteError):
            logger.exception("Failed to insert documents.")
            return None

        return result.inserted_ids

    @classmethod
    def _get_collection_name(cls):
        """Return the collection name declared in the subclass's ``Settings``.

        Raises:
            ImproperlyConfigured: If ``Settings`` or ``Settings.name`` is missing.
        """
        if not hasattr(cls, "Settings") or not hasattr(cls.Settings, "name"):
            raise ImproperlyConfigured(
                "Document should define an Settings configuration class with the name of the collection."
            )

        return cls.Settings.name
|
||
|
||
class UserDocument(BaseDocument):
    # Document model stored in the "users" collection.
    first_name: str
    last_name: str

    class Settings:
        # Collection name looked up by BaseDocument._get_collection_name().
        name = "users"
|
||
|
||
class RepositoryDocument(BaseDocument):
    # Document model stored in the "repositories" collection.
    name: str
    link: str
    content: dict
    # NOTE(review): the alias equals the field name, so it looks redundant — confirm before removing.
    owner_id: str = Field(alias="owner_id")

    class Settings:
        # Collection name looked up by BaseDocument._get_collection_name().
        name = "repositories"
|
||
|
||
class PostDocument(BaseDocument):
    # Document model stored in the "posts" collection.
    platform: str
    content: dict
    # NOTE(review): the alias equals the field name, so it looks redundant — confirm before removing.
    author_id: str = Field(alias="author_id")

    class Settings:
        # Collection name looked up by BaseDocument._get_collection_name().
        name = "posts"
|
||
|
||
class ArticleDocument(BaseDocument):
    # Document model stored in the "articles" collection.
    platform: str
    link: str
    content: dict
    # NOTE(review): the alias equals the field name, so it looks redundant — confirm before removing.
    author_id: str = Field(alias="author_id")

    class Settings:
        # Collection name looked up by BaseDocument._get_collection_name().
        name = "articles"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
class ScrabbleException(Exception):
    """Base exception type that the package's other exceptions derive from."""
|
||
|
||
class ImproperlyConfigured(ScrabbleException):
    """Raised when a class is missing configuration it is required to declare."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from pymongo import MongoClient | ||
from pymongo.errors import ConnectionFailure | ||
|
||
import logger_utils | ||
from settings import settings | ||
|
||
# Module-level logger, bound with this module's name.
logger = logger_utils.get_logger(__name__)
|
||
|
||
class MongoDatabaseConnector:
    """Holder for a single shared ``MongoClient``.

    ``__new__`` returns the cached ``MongoClient`` itself rather than an
    instance of this class, so ``MongoDatabaseConnector()`` evaluates to the
    client object.

    NOTE(review): ``get_database`` and ``close`` are instance methods, yet
    ``__new__`` never hands out an instance of this class — confirm how they
    are meant to be reached.
    """

    # Cached client shared across every call to the constructor.
    _instance: MongoClient = None

    def __new__(cls, *args, **kwargs):
        """Return the shared client, creating it on the first call.

        Raises:
            ConnectionFailure: Re-raised after logging when creating the
                client fails.
        """
        host = settings.MONGO_DATABASE_HOST

        if cls._instance is None:
            try:
                cls._instance = MongoClient(host)
            except ConnectionFailure:
                logger.exception(
                    "Couldn't connect to the database",
                    database_host=host,
                )

                raise

        # Emitted on every call, including calls that reuse the cached client.
        logger.info("Connection to database successful", uri=host)
        return cls._instance

    def get_database(self):
        """Return the database named by ``settings.MONGO_DATABASE_NAME``."""
        database_name = settings.MONGO_DATABASE_NAME
        return self._instance[database_name]

    def close(self):
        """Close the cached client, if there is one, and log that it was closed."""
        if not self._instance:
            return

        self._instance.close()
        logger.info(
            "Connected to database has been closed.",
            uri=settings.MONGO_DATABASE_HOST,
        )
|
||
|
||
# Shared module-level handle; since __new__ returns the cached client, this name is bound to a MongoClient.
connection = MongoDatabaseConnector()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from qdrant_client import QdrantClient | ||
from qdrant_client.http.exceptions import UnexpectedResponse | ||
from qdrant_client.http.models import Batch, Distance, VectorParams | ||
|
||
import logger_utils | ||
from settings import settings | ||
|
||
# Module-level logger, bound with this module's name.
logger = logger_utils.get_logger(__name__)
|
||
|
||
class QdrantDatabaseConnector:
    """Wrapper around a ``QdrantClient`` configured from ``settings``."""

    # Class-level default; __init__ assigns the real client on the instance.
    _instance: QdrantClient = None

    def __init__(self):
        """Create the client unless one is already set.

        Uses the cloud URL and API key when ``settings.USE_QDRANT_CLOUD`` is
        truthy, otherwise the configured host and port.

        Raises:
            UnexpectedResponse: Re-raised after logging when client creation
                fails with this error.
        """
        if self._instance is not None:
            return

        try:
            if settings.USE_QDRANT_CLOUD:
                client = QdrantClient(
                    url=settings.QDRANT_CLOUD_URL,
                    api_key=settings.QDRANT_APIKEY,
                )
            else:
                client = QdrantClient(
                    host=settings.QDRANT_DATABASE_HOST,
                    port=settings.QDRANT_DATABASE_PORT,
                )
        except UnexpectedResponse:
            logger.exception(
                "Couldn't connect to the database.",
                host=settings.QDRANT_DATABASE_HOST,
                port=settings.QDRANT_DATABASE_PORT,
            )

            raise

        self._instance = client

    def get_collection(self, collection_name: str):
        """Return the client's ``get_collection`` result for ``collection_name``."""
        client = self._instance
        return client.get_collection(collection_name=collection_name)

    def create_non_vector_collection(self, collection_name: str):
        """Create a collection with an empty vectors configuration."""
        client = self._instance
        client.create_collection(collection_name=collection_name, vectors_config={})

    def create_vector_collection(self, collection_name: str):
        """Create a cosine-distance collection sized by ``settings.EMBEDDING_SIZE``."""
        vector_params = VectorParams(
            size=settings.EMBEDDING_SIZE, distance=Distance.COSINE
        )
        self._instance.create_collection(
            collection_name=collection_name,
            vectors_config=vector_params,
        )

    def write_data(self, collection_name: str, points: Batch):
        """Upsert ``points`` into ``collection_name``.

        Raises:
            Exception: Any error from the upsert is logged and re-raised.
        """
        client = self._instance
        try:
            client.upsert(collection_name=collection_name, points=points)
        except Exception:
            logger.exception("An error occurred while inserting data.")

            raise

    def scroll(self, collection_name: str, limit: int):
        """Return the client's ``scroll`` result for the collection, capped at ``limit``."""
        client = self._instance
        return client.scroll(collection_name=collection_name, limit=limit)

    def close(self):
        """Close the client, if there is one, and log that it was closed."""
        if not self._instance:
            return

        self._instance.close()

        logger.info("Connected to database has been closed.")
|
||
|
||
# Shared module-level connector, created at import time.
connection = QdrantDatabaseConnector()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import structlog | ||
|
||
|
||
def get_logger(cls: str):
    """Return a structlog logger with ``cls`` bound as the ``cls`` context key."""
    base_logger = structlog.get_logger()
    return base_logger.bind(cls=cls)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from inference import ModelInference | ||
|
||
# Script entry point: send one sample prompt through ModelInference and print the result.
if __name__ == '__main__':
    tool = ModelInference()
    # Prompt text is kept exactly as it appears in the source, line breaks included.
    query = """
Hello my author_id is 1.
Could you please draft a LinkedIn post discussing RAG systems?
I'm particularly interested in how RAG works and how it is integrated with vector DBs and large language models (LLMs).
"""
    content = tool.generate_content(query=query)
    print(content)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
[tool.poetry] | ||
name = "rag-system" | ||
description = "" | ||
version = "0.1.0" | ||
authors = [ | ||
"vlad_adu <[email protected]>", | ||
"Paul Iusztin <[email protected]>", | ||
"Alex Vesa <[email protected]>" | ||
] | ||
#package-mode = false | ||
readme = "README.md" | ||
|
||
[tool.ruff] | ||
line-length = 88 | ||
select = [ | ||
"F401", | ||
"F403", | ||
] | ||
|
||
|
||
[tool.poetry.dependencies] | ||
python = ">=3.10, <3.12" | ||
pydantic = "^2.6.3" | ||
pydantic-settings = "^2.1.0" | ||
bytewax = "0.18.2" | ||
pika = "^1.3.2" | ||
qdrant-client = "^1.8.0" | ||
unstructured = "^0.12.6" | ||
langchain = "^0.1.13" | ||
sentence-transformers = "^2.6.1" | ||
instructorembedding = "^1.0.1" | ||
numpy = "^1.26.4" | ||
langchain-openai = "^0.1.3" | ||
gdown = "^5.1.0" | ||
pymongo = "^4.7.1" | ||
structlog = "^24.1.0" | ||
rich = "^13.7.1" | ||
pip = "^24.0" | ||
install = "^1.3.5" | ||
comet-ml = "^3.41.0" | ||
ruff = "^0.4.3" | ||
comet-llm = "^2.2.4" | ||
utils = "^1.0.2" | ||
qwak-sdk = "^0.5.69" | ||
pandas = "^2.2.2" | ||
|
||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
build-backend = "poetry.core.masonry.api" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
def flatten(nested_list: list) -> list:
    """Concatenate the items of each inner iterable into one flat list.

    Only one level of nesting is removed; items that are themselves lists
    are kept as they are.
    """
    flat: list = []
    for sublist in nested_list:
        flat.extend(sublist)

    return flat
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from langchain.text_splitter import ( | ||
RecursiveCharacterTextSplitter, | ||
SentenceTransformersTokenTextSplitter, | ||
) | ||
|
||
from settings import settings | ||
|
||
|
||
def chunk_text(text: str) -> list[str]:
    """Split ``text`` into chunks using two passes.

    The text is first split on double newlines by a
    ``RecursiveCharacterTextSplitter`` (chunk_size=500, no overlap). Each
    resulting section is then split again by a
    ``SentenceTransformersTokenTextSplitter`` configured from ``settings``
    (chunk_overlap=50).

    Returns:
        The token-level chunks of every section, in section order.
    """
    paragraph_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n"], chunk_size=500, chunk_overlap=0
    )
    sections = paragraph_splitter.split_text(text)

    token_splitter = SentenceTransformersTokenTextSplitter(
        chunk_overlap=50,
        tokens_per_chunk=settings.EMBEDDING_MODEL_MAX_INPUT_LENGTH,
        model_name=settings.EMBEDDING_MODEL_ID,
    )

    return [
        chunk
        for section in sections
        for chunk in token_splitter.split_text(section)
    ]
Oops, something went wrong.