Commit
Showing 29 changed files with 572 additions and 26 deletions.
@@ -0,0 +1,5 @@
from .framework import RAGAFramework
from .document import RAGADocument as RAGADocument

__all__ = ["RAGAFramework", "RAGADocument"]
__version__ = "0.1.0"
@@ -0,0 +1,10 @@
from typing import Dict, Optional


class RAGADocument:
    def __init__(self, content: str, metadata: Optional[Dict] = None):
        self.content = content
        self.metadata = metadata or {}

    def __repr__(self):
        return f"Document(content={self.content[:50]}..., metadata={self.metadata})"
@@ -0,0 +1,17 @@
from .base import BaseEmbedding
from .openai_embeddings import OpenAIEmbedding
from .ollama_embeddngs import OllamaEmbeddings
from .le_mistral_embeddings import LeMistralEmbeddings


def get_embedding_model(
    provider: str, api_key: str, model: str = "ollama"
) -> BaseEmbedding:
    if provider == "ollama":
        return OllamaEmbeddings(api_key, model)
    elif provider == "mistral":
        return LeMistralEmbeddings(api_key, model)
    elif provider == "openai":
        return OpenAIEmbedding(api_key, model)
    else:
        raise ValueError(f"Unsupported embedding provider: {provider}")
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import List

class BaseEmbedding(ABC):
    @abstractmethod
    def embed(self, text: str) -> List[float]:
        pass

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        pass
@@ -0,0 +1,37 @@
from mistralai.models.embeddingresponse import EmbeddingResponse
from mistralai import Mistral
from typing import List, Union
from .base import BaseEmbedding


class LeMistralEmbeddings(BaseEmbedding):
    def __init__(self, api_key: str = "xyz", model_name: str = "nomic-embed-text"):
        self.api_key = api_key
        self.model = model_name
        self.mistral_client = Mistral(api_key=self.api_key)

    def embed(self, text: str) -> List[float]:
        response: Union[EmbeddingResponse, None] = (
            self.mistral_client.embeddings.create(model=self.model, inputs=[text])
        )
        if not response:
            return []

        if not response.data[0].embedding:
            return []
        return response.data[0].embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        response: Union[EmbeddingResponse, None] = (
            self.mistral_client.embeddings.create(model=self.model, inputs=texts)
        )
        if not response:
            return []

        _embeddings = []
        for item in response.data:
            if item.embedding is not None:
                _embeddings.append(item.embedding)
            else:
                _embeddings.append([])
        return _embeddings
@@ -0,0 +1,17 @@
import ollama
from typing import List
from .base import BaseEmbedding


class OllamaEmbeddings(BaseEmbedding):
    def __init__(self, api_key: str = "xyz", model_name: str = "nomic-embed-text"):
        self.api_key = api_key
        self.model = model_name

    def embed(self, text: str) -> List[float]:
        response = ollama.embed(model=self.model, input=[text])
        return response["embeddings"][0]

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        response = ollama.embed(model=self.model, input=texts)
        return response["embeddings"]
@@ -0,0 +1,18 @@
import openai
from typing import List
from .base import BaseEmbedding


class OpenAIEmbedding(BaseEmbedding):
    def __init__(self, api_key: str, model: str = "text-embedding-ada-002"):
        self.api_key = api_key
        self.model = model
        openai.api_key = self.api_key

    def embed(self, text: str) -> List[float]:
        response = openai.Embedding.create(input=[text], model=self.model)
        return response["data"][0]["embedding"]

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        response = openai.Embedding.create(input=texts, model=self.model)
        return [item["embedding"] for item in response["data"]]
@@ -0,0 +1,102 @@
from typing import List, Literal, Optional, Dict
from .document import RAGADocument
from .embeddings import get_embedding_model
from .vectorstores import get_vectorstore
from .llms import get_llm
from .utils.text_splitter import SimpleTextSplitter
from .retrievers.similarity import SimilarityRetriever
from .session import Session


EmbeddingTypes = Literal["ollama", "mistral", "openai"]
LLMTypes = Literal["ollama", "mistral", "openai"]
VectorDBTypes = Literal["chroma", "faiss"]
# LLMModelNames = Literal["llama3.1", "mistral", "mistral-nemo", "gemma2", "gemini-pro"]


class RAGAFramework:
    def __init__(
        self,
        embedding_provider: EmbeddingTypes,
        vectorstore_provider: VectorDBTypes,
        llm_provider: LLMTypes,
        api_keys: Dict[str, str],
        embedding_model: str = "nomic-embed-text",
        llm_model: str = "mistral",
        system_prompt: Optional[str] = None,
        text_splitter: Optional[SimpleTextSplitter] = None,
    ):
""" | ||
RAGAFramework class - A multi-purpose AI assistant framework that utilizes embedding, vectorstore, and LLM services in a unified manner. | ||
Provides methods to add documents, query the model, and customize the system prompt. The framework uses context from the stored documents to generate responses. | ||
Usage example: | ||
``` | ||
from RAGAFramework import RAGAFramework | ||
# Initialize the RAGAFramework instance with certain configurations | ||
framework = RAGAFramework(embedding_provider="HuggingFace", vectorstore_provider="DrQA", llm_provider="RiversideCode") | ||
# Add documents to the framework | ||
# Assuming documents is a list of Document objects as per the SDK's definition | ||
framework.add_documents(documents) | ||
# Set a custom system prompt (optional, defaults to the default one) | ||
framework.system_prompt = "You are an assistant that helps with specific tasks. Please perform the task asked by the user." | ||
# Query the framework with a given question | ||
response = framework.query("What is the capital of France?") | ||
# Print the generated response from the query | ||
print(response) | ||
``` | ||
""" | ||
        self.session = Session()
        self.session.set_api_keys(api_keys)

        self.embedding = get_embedding_model(
            embedding_provider,
            self.session.get_api_key(embedding_provider),
            embedding_model,
        )
        self.vectorstore = get_vectorstore(
            vectorstore_provider, self.session.get_api_key(vectorstore_provider)
        )
        self.llm = get_llm(
            llm_provider, self.session.get_api_key(llm_provider), llm_model
        )

        self.system_prompt = system_prompt or self._default_system_prompt()
        self.text_splitter = text_splitter or SimpleTextSplitter()
        self.retriever = SimilarityRetriever(self.vectorstore, self.embedding)

    def _default_system_prompt(self) -> str:
        return (
            "You are a helpful AI assistant. Use the provided context to answer "
            "questions. If you're unsure or the answer isn't in the context, "
            "say you don't know."
        )

    def add_documents(self, documents: List[RAGADocument]):
        split_docs = []
        for doc in documents:
            chunks = self.text_splitter.split_text(doc.content)
            split_docs.extend([RAGADocument(chunk, doc.metadata) for chunk in chunks])

        embeddings = self.embedding.embed_batch([doc.content for doc in split_docs])
        self.vectorstore.add_documents(split_docs, embeddings)

    def query(self, query: str, k: int = 4) -> str:
        retrieved_docs = self.retriever.retrieve(query, k)
        context = "\n\n".join([doc.content for doc in retrieved_docs])

        prompt = f"{self.system_prompt}\n\nContext:\n{context}\n\nQuestion: {query}\n\nAnswer:"

        return self.llm.generate(prompt)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.session.clear_api_keys()
@@ -0,0 +1,18 @@
from .base import BaseLLM
from .openai import OpenAILLM
from .gemini import GeminiLLM
from .ollama import OllamaLLM
from .le_mistral import LeMistralLLM


def get_llm(provider: str, api_key: str, model: str = "ollama") -> BaseLLM:
    if provider == "ollama":
        return OllamaLLM(api_key, model)
    elif provider == "mistral":
        return LeMistralLLM(api_key, model)
    elif provider == "openai":
        return OpenAILLM(api_key, model)
    elif provider == "gemini":
        return GeminiLLM(api_key, model)
    else:
        raise ValueError(f"Unsupported LLM provider: {provider}")
@@ -0,0 +1,6 @@
from abc import ABC, abstractmethod

class BaseLLM(ABC):
    @abstractmethod
    def generate(self, prompt: str) -> str:
        pass
@@ -0,0 +1,12 @@
import google.generativeai as genai
from .base import BaseLLM


class GeminiLLM(BaseLLM):
    def __init__(self, api_key: str, model: str = "gemini-pro"):
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model)

    def generate(self, prompt: str) -> str:
        response = self.model.generate_content(prompt)
        return response.text
@@ -0,0 +1,26 @@
from typing import Union
from .base import BaseLLM
from mistralai import ChatCompletionResponse, Mistral


class LeMistralLLM(BaseLLM):
    def __init__(self, api_key: str = "xyz", model: str = "llama3.1"):
        self.api_key = api_key
        self.model = model
        self.mistral_client = Mistral(api_key=self.api_key)

    def generate(self, prompt: str) -> str:
        response: Union[None, ChatCompletionResponse] = (
            self.mistral_client.chat.complete(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                stream=False,
            )
        )
        if (
            not response
            or not response.choices
            or not response.choices[0].message.content
        ):
            return ""
        return response.choices[0].message.content
@@ -0,0 +1,16 @@
import ollama
from .base import BaseLLM


class OllamaLLM(BaseLLM):
    def __init__(self, api_key: str = "xyz", model: str = "llama3.1"):
        self.api_key = api_key
        self.model = model

    def generate(self, prompt: str) -> str:
        response = ollama.chat(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            stream=False,
        )
        return response["message"]["content"]
@@ -0,0 +1,15 @@
import openai
from .base import BaseLLM

class OpenAILLM(BaseLLM):
    def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
        self.api_key = api_key
        self.model = model
        openai.api_key = self.api_key

    def generate(self, prompt: str) -> str:
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}]
        )
        return response.choices[0].message.content
@@ -0,0 +1,4 @@
from .base import BaseRetriever
from .similarity import SimilarityRetriever

__all__ = ["BaseRetriever", "SimilarityRetriever"]
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod
from typing import List
from ..document import RAGADocument


class BaseRetriever(ABC):
    @abstractmethod
    def retrieve(self, query: str, k: int = 4) -> List[RAGADocument]:
        pass
@@ -0,0 +1,15 @@
from typing import List
from .base import BaseRetriever
from ..vectorstores.base import BaseVectorStore
from ..embeddings.base import BaseEmbedding
from ..document import RAGADocument


class SimilarityRetriever(BaseRetriever):
    def __init__(self, vectorstore: BaseVectorStore, embedding: BaseEmbedding):
        self.vectorstore = vectorstore
        self.embedding = embedding

    def retrieve(self, query: str, k: int = 4) -> List[RAGADocument]:
        query_embedding = self.embedding.embed(query)
        return self.vectorstore.similarity_search(query_embedding, k)
@@ -0,0 +1,29 @@
import os
from typing import Dict
import uuid


class Session:
    def __init__(self):
        self.session_id = str(uuid.uuid4())

    def set_api_keys(self, api_keys: Dict[str, str]):
        for provider, key in api_keys.items():
            os.environ[f"{self.session_id}_{provider.upper()}_API_KEY"] = key

    def get_api_key(self, provider: str) -> str:
        key = os.environ.get(f"{self.session_id}_{provider.upper()}_API_KEY")
        if not key:
return "key_not_fount" | ||
        return key

    def clear_api_keys(self):
        for key in list(os.environ.keys()):
            if key.startswith(self.session_id):
                del os.environ[key]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.clear_api_keys()
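Not part of this commit: a short sketch of how `Session` scopes API keys to the process environment and removes them on exit. The import path and provider name are assumptions.

```python
import os
from RAGAFramework.session import Session  # assumed package name

with Session() as session:
    session.set_api_keys({"mistral": "secret-key"})
    print(session.get_api_key("mistral"))  # -> "secret-key"
    # keys live under environment variables prefixed with the session id
    print(any(k.startswith(session.session_id) for k in os.environ))  # -> True
# on exit, the session-scoped environment variables are deleted
```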
@@ -0,0 +1,3 @@
from .text_splitter import SimpleTextSplitter

__all__ = ["SimpleTextSplitter"]