Commit

Add Cohere Support

thomashacker committed Nov 9, 2023
1 parent c4963a6 commit f690c88
Showing 14 changed files with 1,010 additions and 37 deletions.
758 changes: 758 additions & 0 deletions data/cohere/cohere_context.txt

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion goldenverba/components/embedding/CohereEmbedder.py
@@ -13,7 +13,6 @@ def __init__(self):
super().__init__()
self.name = "CohereEmbedder"
self.requires_env = ["COHERE_API_KEY"]
self.requires_library = ["openai"]
self.description = (
"Embeds and retrieves objects using Cohere's embed-multilingual-v2.0 model"
)
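The hunk above removes a stale `openai` entry from the Cohere embedder's declared dependencies, leaving `COHERE_API_KEY` as its only environment requirement. As a hedged sketch of how such `requires_env` / `requires_library` declarations can be validated before a component is used (the `check_requirements` helper below is illustrative, not Verba's actual API):

```python
import importlib.util
import os

def check_requirements(requires_env: list[str], requires_library: list[str]) -> bool:
    """Return True if all env vars are set and all libraries are importable."""
    missing_env = [var for var in requires_env if not os.getenv(var)]
    missing_lib = [
        lib for lib in requires_library if importlib.util.find_spec(lib) is None
    ]
    if missing_env or missing_lib:
        print(f"Missing env vars: {missing_env}, missing libraries: {missing_lib}")
        return False
    return True

# After this change, the embedder depends on the Cohere key (and, presumably,
# the cohere SDK) rather than the unrelated openai package.
check_requirements(["COHERE_API_KEY"], ["cohere"])
```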
3 changes: 0 additions & 3 deletions goldenverba/components/embedding/MiniLMEmbedder.py
@@ -26,13 +26,10 @@ def __init__(self):

def get_device():
if torch.cuda.is_available():
msg.info("CUDA is available. Using CUDA...")
return torch.device("cuda")
elif torch.backends.mps.is_available():
msg.info("MPS is available. Using MPS...")
return torch.device("mps")
else:
msg.info("Neither CUDA nor MPS is available. Using CPU...")
return torch.device("cpu")

self.device = get_device()
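The hunk above strips logging from MiniLM's device selection but keeps the CUDA → MPS → CPU fallback. For reference, a minimal standalone sketch of the same pattern, plus the usual follow-up step of moving a model and its inputs onto the chosen device (the `nn.Linear` model below is a placeholder, not the actual MiniLM embedder):

```python
import torch
from torch import nn

def get_device() -> torch.device:
    # Prefer CUDA, then Apple's MPS backend, then fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_device()

# A model and its inputs must live on the same device before a forward pass.
model = nn.Linear(4, 2).to(device)
x = torch.randn(1, 4, device=device)
print(model(x).shape)  # torch.Size([1, 2])
```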
1 change: 0 additions & 1 deletion goldenverba/components/embedding/interface.py
@@ -285,7 +285,6 @@ def conversation_to_query(self, queries: list[str], conversation: dict) -> str:
for _query in queries:
query += _query + " "

print(query.lower())
return query.lower()

def retrieve_semantic_cache(
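The removed `print` was leftover debug output; the visible tail of `conversation_to_query` simply concatenates the queries and lowercases the result. A small reconstruction of just that tail (the real method also receives the conversation dict, whose handling is not shown in this hunk):

```python
def queries_to_string(queries: list[str]) -> str:
    # Concatenate the queries with trailing spaces, then lowercase.
    query = ""
    for _query in queries:
        query += _query + " "
    return query.lower()

print(repr(queries_to_string(["What is Verba?", "How does RAG work?"])))
# Output: 'what is verba? how does rag work? '
```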
150 changes: 150 additions & 0 deletions goldenverba/components/generation/CohereGenerator.py
@@ -0,0 +1,150 @@
import os
import asyncio
from typing import Iterator
from wasabi import msg

from goldenverba.components.generation.interface import Generator


class CohereGenerator(Generator):
"""
Generator using Cohere's command model
"""

def __init__(self):
super().__init__()
self.name = "CohereGenerator"
self.description = "Generator using Cohere's command model"
self.requires_library = ["cohere"]
self.requires_env = ["COHERE_API_KEY"]
self.streamable = False
self.model_name = "command"
self.context_window = 3000

async def generate(
self,
queries: list[str],
context: list[str],
conversation: dict = {},
) -> str:
"""Generate an answer based on a list of queries and list of contexts, and includes conversational context
@parameter: queries : list[str] - List of queries
@parameter: context : list[str] - List of contexts
@parameter: conversation : dict - Conversational context
@returns str - Answer generated by the Generator
"""

message, _conversation = self.prepare_messages(queries, context, conversation)

try:
import cohere

co = cohere.Client(os.getenv("COHERE_API_KEY"))

            # Run the blocking Cohere SDK call in a worker thread via
            # asyncio.to_thread so it does not block the event loop.
            def synchronous_chat_call():
                return co.chat(
                    chat_history=_conversation,
                    message=message,
                    model=self.model_name,
                    temperature=0.1,
                )

            async def asynchronous_chat_call():
                return await asyncio.to_thread(synchronous_chat_call)

chat_obj = await asynchronous_chat_call()

system_msg = str(chat_obj.text)

except Exception as e:
raise e

return system_msg

async def generate_stream(
self,
queries: list[str],
context: list[str],
conversation: dict = {},
) -> Iterator[dict]:
"""Generate a stream of response dicts based on a list of queries and list of contexts, and includes conversational context
@parameter: queries : list[str] - List of queries
@parameter: context : list[str] - List of contexts
@parameter: conversation : dict - Conversational context
@returns Iterator[dict] - Token response generated by the Generator in this format {message: TOKEN, finish_reason: "stop" or ""}
"""

message, _conversation = self.prepare_messages(queries, context, conversation)

try:
import cohere
from cohere.responses.chat import StreamTextGeneration, StreamEnd

co = cohere.Client(os.getenv("COHERE_API_KEY"))

async for chunk in co.chat(
chat_history=_conversation,
stream=True,
message=message,
model="command",
temperature=0.1,
):
if isinstance(chunk, StreamTextGeneration):
yield {
"message": chunk.text,
"finish_reason": "",
}
elif isinstance(chunk, StreamEnd):
yield {
"message": "",
"finish_reason": "stop",
}

        except Exception as e:
            # Log the error and close the stream gracefully.
            msg.warn(str(e))
            yield {
                "message": "",
                "finish_reason": "stop",
            }

    def prepare_messages(
        self, queries: list[str], context: list[str], conversation: dict[str, str]
    ) -> tuple[str, list[dict[str, str]]]:
        """
        Prepares the prompt and chat history for Cohere's chat endpoint, including system instructions, the previous conversation, and a new user query with context.
        @parameter queries: A list of strings representing the user queries to be answered.
        @parameter context: A list of strings representing the context information provided for the queries.
        @parameter conversation: A list of previous conversation messages that include the role and content.
        @returns A tuple of (prompt, chat_history): the new user query encapsulated with the provided context, and the chat history as a list of dictionaries with 'role' and 'message' keys, where 'role' is either 'CHATBOT' or 'USER'.
        """
messages = [
{
"role": "CHATBOT",
"message": f"I am a Retrieval Augmented Generation chatbot. I'll answer user queries only with their provided context. If the provided documentation does not provide enough information, I say so. If the answer requires code examples I encapsulate them with ```programming-language-name ```. I don't do pseudo-code.",
}
]

for message in conversation:
_type = ""
if message.type == "system":
_type = "CHATBOT"
else:
_type = "USER"

messages.append({"role": _type, "message": message.content})

query = " ".join(queries)
user_context = " ".join(context)

prompt = f"Please answer this query: '{query}' with this provided context: {user_context}"

return prompt, messages
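Since Cohere's v4 Python client is synchronous, `generate` pushes the blocking `co.chat` call onto a worker thread with `asyncio.to_thread` so the event loop can keep serving other requests. A self-contained sketch of that pattern with a stand-in blocking function (not the Cohere SDK):

```python
import asyncio
import time

def blocking_call(prompt: str) -> str:
    # Stand-in for a synchronous SDK call such as cohere.Client.chat.
    time.sleep(1)
    return f"answer to: {prompt}"

async def main() -> None:
    # to_thread runs the blocking function in a worker thread; gather shows
    # that other coroutines make progress while it runs.
    answer, _ = await asyncio.gather(
        asyncio.to_thread(blocking_call, "what is RAG?"),
        asyncio.sleep(0.1),
    )
    print(answer)

asyncio.run(main())
```

Note that `asyncio.to_thread` requires Python 3.9+; on older interpreters, `loop.run_in_executor(None, ...)` is the equivalent.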
24 changes: 19 additions & 5 deletions goldenverba/components/generation/GPT4Generator.py
@@ -1,5 +1,6 @@
import os
import asyncio
from typing import Iterator

from goldenverba.components.generation.interface import Generator

@@ -25,7 +26,7 @@ async def generate(
context: list[str],
conversation: dict = {},
) -> str:
"""Generate an answer based on a list of queries and list of contexts, include conversational context
"""Generate an answer based on a list of queries and list of contexts, and includes conversational context
@parameter: queries : list[str] - List of queries
@parameter: context : list[str] - List of contexts
@parameter: conversation : dict - Conversational context
@@ -54,12 +55,12 @@ async def generate_stream(
queries: list[str],
context: list[str],
conversation: dict = {},
) -> str:
"""Generate an answer based on a list of queries and list of contexts, include conversational context
) -> Iterator[dict]:
"""Generate a stream of response dicts based on a list of queries and list of contexts, and includes conversational context
@parameter: queries : list[str] - List of queries
@parameter: context : list[str] - List of contexts
@parameter: conversation : dict - Conversational context
@returns str - Answer generated by the Generator
@returns Iterator[dict] - Token response generated by the Generator in this format {message: TOKEN, finish_reason: "stop" or ""}
"""

messages = self.prepare_messages(queries, context, conversation)
@@ -92,7 +93,20 @@ async def generate_stream(
except Exception as e:
raise e

def prepare_messages(self, queries, context, conversation):
def prepare_messages(
self, queries: list[str], context: list[str], conversation: dict[str, str]
    ) -> list[dict[str, str]]:
"""
Prepares a list of messages formatted for a Retrieval Augmented Generation chatbot system, including system instructions, previous conversation, and a new user query with context.
@parameter queries: A list of strings representing the user queries to be answered.
@parameter context: A list of strings representing the context information provided for the queries.
@parameter conversation: A list of previous conversation messages that include the role and content.
@returns A list of message dictionaries formatted for the chatbot. This includes an initial system message, the previous conversation messages, and the new user query encapsulated with the provided context.
Each message in the list is a dictionary with 'role' and 'content' keys, where 'role' is either 'system' or 'user', and 'content' contains the relevant text. This will depend on the LLM used.
"""
messages = [
{
"role": "system",
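The GPT4 variant of `prepare_messages` builds an OpenAI-style message list; the hunk is truncated after the opening system message. A hedged sketch of the shape it plausibly returns, based on the Cohere variant above (the system prompt text and the `Turn` class are assumptions for illustration):

```python
from dataclasses import dataclass

@dataclass
class Turn:
    type: str     # e.g. "system" or "user"
    content: str

def prepare_messages(
    queries: list[str], context: list[str], conversation: list[Turn]
) -> list[dict[str, str]]:
    # OpenAI-style message list: a system message, the prior conversation,
    # then the new user query wrapped with the retrieved context.
    messages = [
        {
            "role": "system",
            "content": "Answer user queries using only the provided context.",
        }
    ]
    for message in conversation:
        messages.append({"role": message.type, "content": message.content})
    query = " ".join(queries)
    user_context = " ".join(context)
    messages.append(
        {
            "role": "user",
            "content": f"Please answer this query: '{query}' with this provided context: {user_context}",
        }
    )
    return messages

print(prepare_messages(["What is Verba?"], ["Verba is a RAG app."], []))
```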
28 changes: 19 additions & 9 deletions goldenverba/components/generation/Llama2Generator.py
@@ -1,7 +1,9 @@
from goldenverba.components.generation.interface import Generator
from wasabi import msg
import asyncio
import os
import asyncio
from wasabi import msg
from typing import Iterator

from goldenverba.components.generation.interface import Generator


class Llama2Generator(Generator):
@@ -27,13 +29,10 @@ def __init__(self):

def get_device():
if torch.cuda.is_available():
msg.info("CUDA is available. Using CUDA...")
return torch.device("cuda")
elif torch.backends.mps.is_available():
msg.info("MPS is available. Using MPS...")
return torch.device("mps")
else:
msg.info("Neither CUDA nor MPS is available. Using CPU...")
return torch.device("cpu")

self.device = get_device()
@@ -56,12 +55,12 @@ async def generate_stream(
queries: list[str],
context: list[str],
conversation: dict = {},
) -> str:
"""Generate an answer based on a list of queries and list of contexts, include conversational context
) -> Iterator[dict]:
"""Generate a stream of response dicts based on a list of queries and list of contexts, and includes conversational context
@parameter: queries : list[str] - List of queries
@parameter: context : list[str] - List of contexts
@parameter: conversation : dict - Conversational context
@returns str - Answer generated by the Generator
@returns Iterator[dict] - Token response generated by the Generator in this format {message: TOKEN, finish_reason: "stop" or ""}
"""

messages = self.prepare_messages(queries, context, conversation)
@@ -151,6 +150,17 @@ async def generate_stream(
raise e

def prepare_messages(self, queries, context, conversation):
"""
Prepares a list of messages formatted for a Retrieval Augmented Generation chatbot system, including system instructions, previous conversation, and a new user query with context.
@parameter queries: A list of strings representing the user queries to be answered.
@parameter context: A list of strings representing the context information provided for the queries.
@parameter conversation: A list of previous conversation messages that include the role and content.
@returns A list of message dictionaries or a whole prompt string formatted for the chatbot. This includes an initial system message, the previous conversation messages, and the new user query encapsulated with the provided context.
Each message in the list is a dictionary with 'role' and 'content' keys, where 'role' is either 'system' or 'user', and 'content' contains the relevant text. This will depend on the LLM used.
"""
llama_prompt = f"""
<s>[INST] <<SYS>>
You are a Retrieval Augmented Generation chatbot. Answer user queries using only the provided context. If the context does not provide enough information, say so. If the answer requires code examples encapsulate them with ```programming-language-name ```. Don't do pseudo-code. \n<</SYS>>\n\n
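The hunk above shows the opening of a Llama 2 chat-format prompt, with the system instructions inside `<<SYS>>` tags. For reference, a minimal sketch of the single-turn template this format follows (turn handling in the actual method is not fully shown in the diff):

```python
def build_llama2_prompt(system: str, user: str) -> str:
    # Llama 2 chat template: the system prompt sits inside <<SYS>> tags,
    # and each user turn is wrapped in [INST] ... [/INST].
    return f"<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]"

prompt = build_llama2_prompt(
    "You are a Retrieval Augmented Generation chatbot. "
    "Answer user queries using only the provided context.",
    "Please answer this query: 'What is Verba?' with this provided context: ...",
)
print(prompt)
```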
28 changes: 24 additions & 4 deletions goldenverba/components/generation/interface.py
@@ -1,3 +1,5 @@
from typing import Any, Iterator

from goldenverba.components.component import VerbaComponent


@@ -17,7 +19,7 @@ async def generate(
context: list[str],
conversation: dict = {},
) -> str:
"""Generate an answer based on a list of queries and list of contexts, include conversational context
"""Generate an answer based on a list of queries and list of contexts, and includes conversational context
@parameter: queries : list[str] - List of queries
@parameter: context : list[str] - List of contexts
@parameter: conversation : dict - Conversational context
@@ -30,13 +32,31 @@ async def generate_stream(
queries: list[str],
context: list[str],
conversation: dict = {},
) -> str:
"""Generate an answer based on a list of queries and list of contexts, include conversational context
) -> Iterator[dict]:
"""Generate a stream of response dicts based on a list of queries and list of contexts, and includes conversational context
@parameter: queries : list[str] - List of queries
@parameter: context : list[str] - List of contexts
@parameter: conversation : dict - Conversational context
@returns str - Answer generated by the Generator
@returns Iterator[dict] - Token response generated by the Generator in this format {message: TOKEN, finish_reason: "stop" or ""}
"""
raise NotImplementedError(
"generate_stream method must be implemented by a subclass."
)

def prepare_messages(
self, queries: list[str], context: list[str], conversation: dict[str, str]
    ) -> Any:
"""
Prepares a list of messages formatted for a Retrieval Augmented Generation chatbot system, including system instructions, previous conversation, and a new user query with context.
@parameter queries: A list of strings representing the user queries to be answered.
@parameter context: A list of strings representing the context information provided for the queries.
@parameter conversation: A list of previous conversation messages that include the role and content.
@returns A list of message dictionaries or a whole prompt string formatted for the chatbot. This includes an initial system message, the previous conversation messages, and the new user query encapsulated with the provided context.
Each message in the list is a dictionary with 'role' and 'content' keys, where 'role' is either 'system' or 'user', and 'content' contains the relevant text. This will depend on the LLM used.
"""
raise NotImplementedError(
"prepare_messages method must be implemented by a subclass."
)
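Taken together, the interface defines the contract a new generator must satisfy: `generate`, `generate_stream`, and `prepare_messages`. A hedged sketch of a minimal subclass, useful for wiring tests (the `EchoGenerator` below is illustrative and assumes the goldenverba package is importable):

```python
from typing import Iterator

from goldenverba.components.generation.interface import Generator


class EchoGenerator(Generator):
    """Toy generator that echoes the prepared prompt instead of calling an LLM."""

    def __init__(self):
        super().__init__()
        self.name = "EchoGenerator"
        self.streamable = True

    async def generate(self, queries, context, conversation={}) -> str:
        return self.prepare_messages(queries, context, conversation)

    async def generate_stream(self, queries, context, conversation={}) -> Iterator[dict]:
        # Stream the prompt back token by token in the documented format.
        for token in self.prepare_messages(queries, context, conversation).split():
            yield {"message": token + " ", "finish_reason": ""}
        yield {"message": "", "finish_reason": "stop"}

    def prepare_messages(self, queries, context, conversation) -> str:
        return f"QUERY: {' '.join(queries)} CONTEXT: {' '.join(context)}"
```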