Add support for more llm options
emirsahin1 committed Jan 4, 2025
1 parent 4672d19 commit 7d5cc8c
Showing 3 changed files with 31 additions and 33 deletions.
50 changes: 29 additions & 21 deletions llm_axe/agents.py
@@ -10,7 +10,7 @@ class Agent:
     Basic agent that can use premade or custom system prompts.
     Custom system prompt will override any premade prompts.
     """
-    def __init__(self, llm:object, agent_type:AgentType=None, additional_system_instructions:str="", custom_system_prompt:str=None, format:str="", temperature:float=0.8):
+    def __init__(self, llm:object, agent_type:AgentType=None, additional_system_instructions:str="", custom_system_prompt:str=None, format:str="", temperature:float=0.8, **llm_options):
         """
         Args:
             llm (object): An LLM object with an ask function.
@@ -33,6 +33,7 @@ def __init__(self, llm:object, agent_type:AgentType=None, additional_system_inst
         self.system_prompt = make_prompt("system", self.system_prompt.format(additional_instructions=additional_system_instructions))
         self.temperature = temperature
         self.format = format
+        self.llm_options = llm_options
 
     def get_prompt(self, question):
         """
@@ -61,7 +62,7 @@ def ask(self, prompt, images:list=None, history:list=None):
             prompts.extend(history)
 
         prompts.append(make_prompt("user", prompt, images))
-        response = self.llm.ask(prompts, temperature=self.temperature, format=self.format)
+        response = self.llm.ask(prompts, temperature=self.temperature, format=self.format, **self.llm_options)
 
         self.chat_history.append(prompts[-1])
         self.chat_history.append(make_prompt("assistant", response))
@@ -73,7 +74,7 @@ class ObjectDetectorAgent():
     An ObjectDetectorAgent agent is used to detect objects in an image.
     Requires a multimodal LLM.
     """
-    def __init__(self, vision_llm:object, text_llm:object, vision_temperature:float=0.3, text_temperature:float=0.3):
+    def __init__(self, vision_llm:object, text_llm:object, vision_temperature:float=0.3, text_temperature:float=0.3, **llm_options):
         """
         Initializes a new ObjectDetectorAgent object.
@@ -88,6 +89,7 @@ def __init__(self, vision_llm:object, text_llm:object, vision_temperature:float=
         self.__system_prompt = make_prompt("system", get_yaml_prompt("system_prompts.yaml", "ObjectDetector"))
         self.vision_temperature = vision_temperature
         self.text_temperature = text_temperature
+        self.llm_options = llm_options
 
     def detect(self, images:list, objects:list=None, detection_criteria:str=None):
         """
@@ -103,9 +105,9 @@ def detect(self, images:list, objects:list=None, detection_criteria:str=None):
             return None
 
         prompts = make_prompt("user", "Detect all objects in this image", images=images)
-        detected_objects = self.vision_llm.ask([self.__system_prompt, prompts], temperature=self.vision_temperature)
+        detected_objects = self.vision_llm.ask([self.__system_prompt, prompts], temperature=self.vision_temperature, **self.llm_options)
         prompts = self.__get_prompt(detected_objects, objects, detection_criteria)
-        response = self.text_llm.ask(prompts, format="json", temperature=self.text_temperature)
+        response = self.text_llm.ask(prompts, format="json", temperature=self.text_temperature, **self.llm_options)
 
         return response
 
@@ -141,7 +143,7 @@ class PythonAgent():
     It will provide code to execute and the imports used in the code.
     IMPORTANT!!: Code should ALWAYS be executed in a virtual or isolated environment.
     """
-    def __init__(self, llm:object, temperature:float=0.8):
+    def __init__(self, llm:object, temperature:float=0.8, **llm_options):
         """
         Initializes a new PythonAgent object.
@@ -154,6 +156,7 @@ def __init__(self, llm:object, temperature:float=0.8):
         self.system_prompt = make_prompt("system", get_yaml_prompt("system_prompts.yaml", "PythonAgent"))
         self.library_extractor_prompt = make_prompt("system", get_yaml_prompt("system_prompts.yaml", "ImportExtractor"))
         self.temperature = temperature
+        self.llm_options = llm_options
 
     def ask(self, prompt, history:list=None):
         """
@@ -177,7 +180,7 @@ def ask(self, prompt, history:list=None):
         user_prompt = make_prompt("user", prompt)
         coder_prompts.append(user_prompt)
 
-        code_response = self.llm.ask(coder_prompts, temperature=self.temperature)
+        code_response = self.llm.ask(coder_prompts, temperature=self.temperature, **self.llm_options)
         self.chat_history.append(user_prompt)
         self.chat_history.append(make_prompt("assistant", code_response))
 
@@ -188,7 +191,7 @@ def ask(self, prompt, history:list=None):
         code = code.replace("Python", "")
 
         # Extract imports
-        imports = self.llm.ask([self.library_extractor_prompt, make_prompt("user", code_response)], format="json", temperature=self.temperature)
+        imports = self.llm.ask([self.library_extractor_prompt, make_prompt("user", code_response)], format="json", temperature=self.temperature, **self.llm_options)
         self.chat_history.append(make_prompt("assistant", imports))
         imports = safe_read_json(imports)
 
@@ -199,7 +202,7 @@ class DataExtractor():
     """
     A DataExtractor agent is used to extract information from given content.
     """
-    def __init__(self, llm:object, reply_as_json:bool=False, additional_system_instructions:str="", temperature:float=0.8):
+    def __init__(self, llm:object, reply_as_json:bool=False, additional_system_instructions:str="", temperature:float=0.8, **llm_options):
         """
         Initializes a new DataExtractor.
         Args:
@@ -216,6 +219,7 @@ def __init__(self, llm:object, reply_as_json:bool=False, additional_system_instr
         self.llm = llm
         self.chat_history = []
         self.temperature = temperature
+        self.llm_options = llm_options
 
     def get_prompt(self, info:str, data_points:list=[]):
         """
@@ -243,7 +247,7 @@ def ask(self, info:str, data_points:list=[]):
                 Example: ["name", "age", "city"] Defaults to None.
         """
         prompts = self.get_prompt(info, data_points)
-        resp = self.llm.ask(self.get_prompt(info, data_points), temperature=self.temperature)
+        resp = self.llm.ask(self.get_prompt(info, data_points), temperature=self.temperature, **self.llm_options)
         self.chat_history.append(prompts[1])
         self.chat_history.append(make_prompt("assistant", resp))
         return resp
@@ -254,7 +258,7 @@ class PdfReader():
     An Agent used to answer questions based on information from given PDF files.
     """
 
-    def __init__(self, llm:object, additional_system_instructions:str="", custom_system_prompt:str=None, temperature:float=0.8):
+    def __init__(self, llm:object, additional_system_instructions:str="", custom_system_prompt:str=None, temperature:float=0.8, **llm_options):
         """
         Initializes a new PdfReader.
         Args:
@@ -269,6 +273,7 @@ def __init__(self, llm:object, additional_system_instructions:str="", custom_sys
         self.system_prompt = get_yaml_prompt("system_prompts.yaml", "DocumentReader")
         self.custom_system_prompt = custom_system_prompt
         self.temperature = temperature
+        self.llm_options = llm_options
 
 
     def ask(self, question:str, pdf_files:list, history:list=None):
@@ -290,7 +295,7 @@ def ask(self, question:str, pdf_files:list, history:list=None):
             prompts.extend(history)
 
         prompts.append(question_prompts[1])
-        response = self.llm.ask(prompts, temperature=self.temperature)
+        response = self.llm.ask(prompts, temperature=self.temperature, **self.llm_options)
 
         self.chat_history.append(question_prompts[1]) # dont include the system prompt
         self.chat_history.append(make_prompt("assistant", response))
@@ -315,7 +320,7 @@ def get_prompt(self, question, pdf_files:list=None):
 
         self.system_prompt = make_prompt("system", self.system_prompt.format(additional_instructions=self.additional_instructions))
 
-        user_prompt = make_prompt("user", pdf_text + "\n" + question)
+        user_prompt = make_prompt("user", pdf_text + "\nUser's question: " + question)
         prompts = [self.system_prompt, user_prompt]
 
         return prompts
@@ -329,7 +334,7 @@ class FunctionCaller():
     have type annotations doc string descriptions.
     """
 
-    def __init__(self, llm:object, functions:list, additional_system_instructions:str="", custom_system_prompt:str=None, temperature:float=0.8):
+    def __init__(self, llm:object, functions:list, additional_system_instructions:str="", custom_system_prompt:str=None, temperature:float=0.8, **llm_options):
         """
         Initializes a new Function Caller.
@@ -346,6 +351,7 @@ def __init__(self, llm:object, functions:list, additional_system_instructions:st
         self.additional_instructions = additional_system_instructions
         self.schema = generate_schema(functions)
         self.temperature = temperature
+        self.llm_options = llm_options
 
         if custom_system_prompt is None:
             self.system_prompt = get_yaml_prompt("system_prompts.yaml", "FunctionCaller")
@@ -388,7 +394,7 @@ def get_function(self, question, history:list=None):
             prompts.extend(history)
         prompts.append(question_prompts[1])
 
-        response = self.llm.ask(prompts, format="json", temperature=self.temperature)
+        response = self.llm.ask(prompts, format="json", temperature=self.temperature, **self.llm_options)
         response_json = safe_read_json(response)
 
         self.chat_history.append(question_prompts[1])
@@ -442,7 +448,7 @@ class WebsiteReaderAgent:
     An agent that will read a specificwebsite and answer questions based on it.
     """
 
-    def __init__(self, llm:object, additional_system_instructions:str="", custom_site_reader:callable=None, temperature:float=0.8):
+    def __init__(self, llm:object, additional_system_instructions:str="", custom_site_reader:callable=None, temperature:float=0.8, **llm_options):
         """
         Args:
             llm (object): An LLM object. Must have an ask method.
@@ -456,6 +462,7 @@ def __init__(self, llm:object, additional_system_instructions:str="", custom_sit
         self.additional_system_instructions = additional_system_instructions
         self.read_function = custom_site_reader if custom_site_reader else read_website
         self.temperature = temperature
+        self.llm_options = llm_options
 
     def ask(self, question:str, url:str, history:list=None):
         """
@@ -485,7 +492,7 @@ def ask(self, question:str, url:str, history:list=None):
         user_prompt = make_prompt("user", question)
         prompts.append(user_prompt)
 
-        response = self.llm.ask(prompts, temperature=self.temperature)
+        response = self.llm.ask(prompts, temperature=self.temperature, **self.llm_options)
         self.chat_history.append(user_prompt)
         self.chat_history.append(make_prompt("assistant", response))
         return response
@@ -497,7 +504,7 @@ class OnlineAgent:
     It will use the internet to try and best answer the user prompt.
     """
 
-    def __init__(self, llm:object, additional_system_instructions:str="", custom_searcher:callable=None, custom_site_reader:callable=None, temperature:float=0.8):
+    def __init__(self, llm:object, additional_system_instructions:str="", custom_searcher:callable=None, custom_site_reader:callable=None, temperature:float=0.8, **llm_options):
         """
         Args:
             llm (object): An LLM object. Must have an ask method.
@@ -513,6 +520,7 @@ def __init__(self, llm:object, additional_system_instructions:str="", custom_sea
         self.search_function = custom_searcher if custom_searcher else internet_search
         self.site_reader_function = custom_site_reader if custom_site_reader else read_website
         self.temperature = temperature
+        self.llm_options = llm_options
 
     def search(self, prompt, history:list=None):
         """
@@ -546,7 +554,7 @@ def search(self, prompt, history:list=None):
         url_picker_prompt = make_prompt("user", url_picker_prompt.format(question=prompt, urls=search_results))
         url_picker_prompts.append(url_picker_prompt)
 
-        resp = self.llm.ask(url_picker_prompts, format="json", temperature=self.temperature)
+        resp = self.llm.ask(url_picker_prompts, format="json", temperature=self.temperature, **self.llm_options)
         resp_json = safe_read_json(resp)
 
         self.chat_history.append(url_picker_prompt)
@@ -573,7 +581,7 @@ def search(self, prompt, history:list=None):
         Start your answer with "Based on information from the internet, "
         '''
 
-        final_responder = Agent(llm=self.llm, agent_type=AgentType.GENERIC_RESPONDER)
+        final_responder = Agent(llm=self.llm, agent_type=AgentType.GENERIC_RESPONDER, temperature=self.temperature, **self.llm_options)
         response = final_responder.ask(user_prompt, history=history)
 
         self.chat_history.append(make_prompt("user", user_prompt))
@@ -584,7 +592,7 @@ def search(self, prompt, history:list=None):
     def get_search_query(self, question):
         user_prompt = make_prompt("user", question)
         prompts = [self.system_prompt, user_prompt]
-        response = self.llm.ask(prompts, format="json", temperature=self.temperature)
+        response = self.llm.ask(prompts, format="json", temperature=self.temperature, **self.llm_options)
         response_json = safe_read_json(response)
         self.chat_history.append(prompts[1])
         self.chat_history.append(make_prompt("assistant", response))
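
The agents.py changes all follow one pattern: each constructor grows a **llm_options catch-all, stores it on the instance, and forwards it to every llm.ask call. A minimal usage sketch follows; the import paths mirror the file layout shown above, while the model name "llama3" and the num_ctx/top_p options are illustrative assumptions rather than part of this commit:

    from llm_axe.agents import Agent, AgentType
    from llm_axe.models import OllamaChat

    llm = OllamaChat(model="llama3")  # assumed model name; use any model your Ollama server hosts

    # Extra keyword arguments land in Agent.llm_options and are forwarded
    # verbatim to llm.ask() on every request.
    agent = Agent(
        llm,
        agent_type=AgentType.GENERIC_RESPONDER,
        temperature=0.4,
        num_ctx=8192,  # assumed Ollama option; any option name passes through unchanged
        top_p=0.9,
    )

    print(agent.ask("In one sentence, what can you do?"))
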
4 changes: 2 additions & 2 deletions llm_axe/models.py
@@ -11,12 +11,12 @@ def __init__(self, host:str="http://localhost:11434", model:str=None):
         self._model = model
         self._ollama = Client(host)
 
-    def ask(self, prompts:list, format:str="", temperature:float=0.8):
+    def ask(self, prompts:list, format:str="", temperature:float=0.8, stream:bool=False, **options):
         """
         Args:
             prompts (list): A list of prompts to ask.
             format (str, optional): The format of the response. Use "json" for json. Defaults to "".
             temperature (float, optional): The temperature of the LLM. Defaults to 0.8.
         """
-        return self._ollama.chat(model=self._model, messages=prompts, format=format, options={"temperature": temperature})["message"]["content"]
+        return self._ollama.chat(model=self._model, messages=prompts, format=format, options={"temperature": temperature, **options}, stream=stream)["message"]["content"]
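
The models.py hunk shows where those options end up: extra keyword arguments are merged into the options dict sent to the Ollama client alongside temperature, and a stream flag is threaded through to the chat call. A direct-call sketch, again treating "llama3" and the num_predict/seed options as assumptions:

    from llm_axe.models import OllamaChat

    llm = OllamaChat(model="llama3")  # assumed model name
    prompts = [{"role": "user", "content": "Reply with exactly one word."}]

    # num_predict and seed ride along inside options={"temperature": ..., **options}.
    reply = llm.ask(prompts, temperature=0.2, num_predict=16, seed=42)
    print(reply)

Note that the trailing ["message"]["content"] indexing assumes the default stream=False; with stream=True the ollama client returns an iterator of chunks rather than a single message dict.
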

10 changes: 0 additions & 10 deletions tests/test.py
@@ -28,7 +28,6 @@ def test_ask_with_history(self):
 
         agent = Agent(self.llm_mock, agent_type=AgentType.GENERIC_RESPONDER)
         response = agent.ask(prompt, history=[{"role": "user", "content": "Hello"}])
-        print(agent.chat_history[0])
         self.assertEqual(response, mock_resp)
         self.assertEqual(agent.chat_history[0]["role"], "user")
         self.assertEqual(agent.chat_history[0]["content"], "What is the meaning of life?")
@@ -102,15 +101,6 @@ def test_ask_with_history(self):
         self.assertEqual(agent.chat_history[0]["role"], "user")
         self.assertEqual(agent.chat_history[1]["role"], "assistant")
 
-    def test_get_prompt(self):
-        prompt = "What is the meaning of life?"
-        agent = PdfReader(self.llm_mock)
-
-        with patch('llm_axe.agents.read_pdf') as read_pdf_mock:
-            read_pdf_mock.return_value = "Pdf output"
-            prompts = agent.get_prompt(prompt, ["pdf1.pdf"])
-            self.assertEqual(prompts, [agent.system_prompt, {"role": "user", "content": prompt}])
-
 
 class TestFunctionCaller(unittest.TestCase):
 
