Skip to content

Commit

Permalink
add rag system
Browse files Browse the repository at this point in the history
  • Loading branch information
schartz committed Oct 8, 2024
1 parent 7ff9865 commit c6a6422
Show file tree
Hide file tree
Showing 29 changed files with 572 additions and 26 deletions.
5 changes: 5 additions & 0 deletions rag_system/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .framework import RAGAFramework
from .document import RAGADocument as RAGADocument

__all__ = ["RAGAFramework", "RAGADocument"]
__version__ = "0.1.0"
10 changes: 10 additions & 0 deletions rag_system/document.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Dict, Optional


class RAGADocument:
def __init__(self, content: str, metadata: Optional[Dict] = None):
self.content = content
self.metadata = metadata or {}

def __repr__(self):
return f"Document(content={self.content[:50]}..., metadata={self.metadata})"
17 changes: 17 additions & 0 deletions rag_system/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .base import BaseEmbedding
from .openai_embeddings import OpenAIEmbedding
from .ollama_embeddngs import OllamaEmbeddings
from .le_mistral_embeddings import LeMistralEmbeddings


def get_embedding_model(
provider: str, api_key: str, model: str = "ollama"
) -> BaseEmbedding:
if provider == "ollama":
return OllamaEmbeddings(api_key, model)
elif provider == "mistral":
return LeMistralEmbeddings(api_key, model)
elif provider == "openai":
return OpenAIEmbedding(api_key, model)
else:
raise ValueError(f"Unsupported embedding provider: {provider}")
11 changes: 11 additions & 0 deletions rag_system/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod
from typing import List

class BaseEmbedding(ABC):
@abstractmethod
def embed(self, text: str) -> List[float]:
pass

@abstractmethod
def embed_batch(self, texts: List[str]) -> List[List[float]]:
pass
37 changes: 37 additions & 0 deletions rag_system/embeddings/le_mistral_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from mistralai.models.embeddingresponse import EmbeddingResponse
from mistralai import Mistral
from typing import List, Union
from .base import BaseEmbedding


class LeMistralEmbeddings(BaseEmbedding):
def __init__(self, api_key: str = "xyz", model_name: str = "nomic-embed-text"):
self.api_key = api_key
self.model = model_name
self.mistral_client = Mistral(api_key=self.api_key)

def embed(self, text: str) -> List[float]:
response: Union[EmbeddingResponse, None] = (
self.mistral_client.embeddings.create(model=self.model, inputs=[text])
)
if not response:
return []

if not response.data[0].embedding:
return []
return response.data[0].embedding

def embed_batch(self, texts: List[str]) -> List[List[float]]:
response: Union[EmbeddingResponse, None] = (
self.mistral_client.embeddings.create(model=self.model, inputs=texts)
)
if not response:
return []

_embeddings = []
for item in response.data:
if item.embedding is not None:
_embeddings.append(item.embedding)
else:
_embeddings.append([])
return _embeddings
17 changes: 17 additions & 0 deletions rag_system/embeddings/ollama_embeddngs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import ollama
from typing import List
from .base import BaseEmbedding


class OllamaEmbeddings(BaseEmbedding):
def __init__(self, api_key: str = "xyz", model_name: str = "nomic-embed-text"):
self.api_key = api_key
self.model = model_name

def embed(self, text: str) -> List[float]:
response = ollama.embed(model=self.model, input=[text])
return response["embeddings"][0]

def embed_batch(self, texts: List[str]) -> List[List[float]]:
response = ollama.embed(model=self.model, input=texts)
return response["embeddings"]
18 changes: 18 additions & 0 deletions rag_system/embeddings/openai_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import openai
from typing import List
from .base import BaseEmbedding


class OpenAIEmbedding(BaseEmbedding):
def __init__(self, api_key: str, model: str = "text-embedding-ada-002"):
self.api_key = api_key
self.model = model
openai.api_key = self.api_key

def embed(self, text: str) -> List[float]:
response = openai.Embedding.create(input=[text], model=self.model)
return response["data"][0]["embedding"]

def embed_batch(self, texts: List[str]) -> List[List[float]]:
response = openai.Embedding.create(input=texts, model=self.model)
return [item["embedding"] for item in response["data"]]
102 changes: 102 additions & 0 deletions rag_system/framework.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import List, Literal, Optional, Dict
from .document import RAGADocument
from .embeddings import get_embedding_model
from .vectorstores import get_vectorstore
from .llms import get_llm
from .utils.text_splitter import SimpleTextSplitter
from .retrievers.similarity import SimilarityRetriever
from .session import Session


EmbeddingTypes = Literal["ollama", "mistral", "openai"]
LLMTypes = Literal["ollama", "mistral", "openai"]
VectorDBTypes = Literal["chroma", "faiss"]
# LLMModelNames = Literal["llama3.1", "mistral", "mistral-meno", "gemma2", "gemeni-pro"]


class RAGAFramework:
def __init__(
self,
embedding_provider: EmbeddingTypes,
vectorstore_provider: VectorDBTypes,
llm_provider: LLMTypes,
api_keys: Dict[str, str],
embedding_model: str = "nomic-embed-text",
llm_model: str = "mistral",
system_prompt: Optional[str] = None,
text_splitter: Optional[SimpleTextSplitter] = None,
):
"""
RAGAFramework class - A multi-purpose AI assistant framework that utilizes embedding, vectorstore, and LLM services in a unified manner.
Provides methods to add documents, query the model, and customize the system prompt. The framework uses context from the stored documents to generate responses.
Usage example:
```
from RAGAFramework import RAGAFramework
# Initialize the RAGAFramework instance with certain configurations
framework = RAGAFramework(embedding_provider="HuggingFace", vectorstore_provider="DrQA", llm_provider="RiversideCode")
# Add documents to the framework
# Assuming documents is a list of Document objects as per the SDK's definition
framework.add_documents(documents)
# Set a custom system prompt (optional, defaults to the default one)
framework.system_prompt = "You are an assistant that helps with specific tasks. Please perform the task asked by the user."
# Query the framework with a given question
response = framework.query("What is the capital of France?")
# Print the generated response from the query
print(response)
```
"""
self.session = Session()
self.session.set_api_keys(api_keys)

self.embedding = get_embedding_model(
embedding_provider,
self.session.get_api_key(embedding_provider),
embedding_model,
)
self.vectorstore = get_vectorstore(
vectorstore_provider, self.session.get_api_key(vectorstore_provider)
)
self.llm = get_llm(
llm_provider, self.session.get_api_key(llm_provider), llm_model
)

self.system_prompt = system_prompt or self._default_system_prompt()
self.text_splitter = text_splitter or SimpleTextSplitter()
self.retriever = SimilarityRetriever(self.vectorstore, self.embedding)

def _default_system_prompt(self) -> str:
return (
"You are a helpful AI assistant. Use the provided context to answer "
"questions. If you're unsure or the answer isn't in the context, "
"say you don't know."
)

def add_documents(self, documents: List[RAGADocument]):
split_docs = []
for doc in documents:
chunks = self.text_splitter.split_text(doc.content)
split_docs.extend([RAGADocument(chunk, doc.metadata) for chunk in chunks])

embeddings = self.embedding.embed_batch([doc.content for doc in split_docs])
self.vectorstore.add_documents(split_docs, embeddings)

def query(self, query: str, k: int = 4) -> str:
retrieved_docs = self.retriever.retrieve(query, k)
context = "\n\n".join([doc.content for doc in retrieved_docs])

prompt = f"{self.system_prompt}\n\nContext:\n{context}\n\nQuestion: {query}\n\nAnswer:"

return self.llm.generate(prompt)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.session.clear_api_keys()
18 changes: 18 additions & 0 deletions rag_system/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from .base import BaseLLM
from .openai import OpenAILLM
from .gemini import GeminiLLM
from .ollama import OllamaLLM
from .le_mistral import LeMistralLLM


def get_llm(provider: str, api_key: str, model: str = "ollama") -> BaseLLM:
if provider == "ollama":
return OllamaLLM(api_key, model)
elif provider == "mistral":
return LeMistralLLM(api_key, model)
elif provider == "openai":
return OpenAILLM(api_key, model)
elif provider == "gemini":
return GeminiLLM(api_key, model)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
6 changes: 6 additions & 0 deletions rag_system/llms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from abc import ABC, abstractmethod

class BaseLLM(ABC):
@abstractmethod
def generate(self, prompt: str) -> str:
pass
12 changes: 12 additions & 0 deletions rag_system/llms/gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import google.generativeai as genai
from .base import BaseLLM


class GeminiLLM(BaseLLM):
def __init__(self, api_key: str, model: str = "gemini-pro"):
genai.configure(api_key=api_key)
self.model = genai.GenerativeModel(model)

def generate(self, prompt: str) -> str:
response = self.model.generate_content(prompt)
return response.text
26 changes: 26 additions & 0 deletions rag_system/llms/le_mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Union
from .base import BaseLLM
from mistralai import ChatCompletionResponse, Mistral


class LeMistralLLM(BaseLLM):
def __init__(self, api_key: str = "xyz", model: str = "llama3.1"):
self.api_key = api_key
self.model = model
self.mistral_client = Mistral(api_key=self.api_key)

def generate(self, prompt: str) -> str:
response: Union[None, ChatCompletionResponse] = (
self.mistral_client.chat.complete(
model=self.model,
messages=[{"role": "user", "content": prompt}],
stream=False,
)
)
if (
not response
or not response.choices
or not response.choices[0].message.content
):
return ""
return response.choices[0].message.content
16 changes: 16 additions & 0 deletions rag_system/llms/ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import ollama
from .base import BaseLLM


class OllamaLLM(BaseLLM):
def __init__(self, api_key: str = "xyz", model: str = "llama3.1"):
self.api_key = api_key
self.model = model

def generate(self, prompt: str) -> str:
response = ollama.chat(
model=self.model,
messages=[{"role": "user", "content": prompt}],
stream=False,
)
return response["message"]["content"]
15 changes: 15 additions & 0 deletions rag_system/llms/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import openai
from .base import BaseLLM

class OpenAILLM(BaseLLM):
def __init__(self, api_key: str, model: str = "gpt-3.5-turbo"):
self.api_key = api_key
self.model = model
openai.api_key = self.api_key

def generate(self, prompt: str) -> str:
response = openai.ChatCompletion.create(
model=self.model,
messages=[{"role": "user", "content": prompt}]
)
return response.choices[0].message.content
4 changes: 4 additions & 0 deletions rag_system/retrievers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base import BaseRetriever
from .similarity import SimilarityRetriever

__all__ = ["BaseRetriever", "SimilarityRetriever"]
9 changes: 9 additions & 0 deletions rag_system/retrievers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod
from typing import List
from ..document import RAGADocument


class BaseRetriever(ABC):
@abstractmethod
def retrieve(self, query: str, k: int = 4) -> List[RAGADocument]:
pass
15 changes: 15 additions & 0 deletions rag_system/retrievers/similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import List
from .base import BaseRetriever
from ..vectorstores.base import BaseVectorStore
from ..embeddings.base import BaseEmbedding
from ..document import RAGADocument


class SimilarityRetriever(BaseRetriever):
def __init__(self, vectorstore: BaseVectorStore, embedding: BaseEmbedding):
self.vectorstore = vectorstore
self.embedding = embedding

def retrieve(self, query: str, k: int = 4) -> List[RAGADocument]:
query_embedding = self.embedding.embed(query)
return self.vectorstore.similarity_search(query_embedding, k)
29 changes: 29 additions & 0 deletions rag_system/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
from typing import Dict
import uuid


class Session:
def __init__(self):
self.session_id = str(uuid.uuid4())

def set_api_keys(self, api_keys: Dict[str, str]):
for provider, key in api_keys.items():
os.environ[f"{self.session_id}_{provider.upper()}_API_KEY"] = key

def get_api_key(self, provider: str) -> str:
key = os.environ.get(f"{self.session_id}_{provider.upper()}_API_KEY")
if not key:
return "key_not_fount"
return key

def clear_api_keys(self):
for key in list(os.environ.keys()):
if key.startswith(self.session_id):
del os.environ[key]

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.clear_api_keys()
3 changes: 3 additions & 0 deletions rag_system/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .text_splitter import SimpleTextSplitter

__all__ = ["SimpleTextSplitter"]
Loading

0 comments on commit c6a6422

Please sign in to comment.