Skip to content

Commit

Permalink
added interface for web_search
Browse files Browse the repository at this point in the history
  • Loading branch information
nalaso committed Mar 29, 2024
1 parent 2aaac6a commit 65c5bfd
Show file tree
Hide file tree
Showing 11 changed files with 151 additions and 13 deletions.
5 changes: 4 additions & 1 deletion devika.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@ def execute_agent():
prompt = data.get("prompt")
base_model = data.get("base_model")
project_name = data.get("project_name")
web_search = None
if(data.get("web_search")):
web_search = data.get("web_search")

if not base_model:
return jsonify({"error": "base_model is required"})

thread = Thread(
target=lambda: Agent(base_model=base_model).execute(prompt, project_name)
target=lambda: Agent(base_model=base_model).execute(prompt, project_name, web_search)
)
thread.start()

Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ GitPython
netlify-py
Markdown
xhtml2pdf
groq
groq
google-generativeai
duckduckgo-search
5 changes: 5 additions & 0 deletions sample.config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@ PDFS_DIR = "pdfs"
PROJECTS_DIR = "projects"
LOGS_DIR = "logs"
REPOS_DIR = "repos"
WEB_SEARCH = "ddgs"

[API_KEYS]
BING = "<YOUR_BING_API_KEY>"
GOOGLE_SEARCH = "<YOUR_GOOGLE_SEARCH_API_KEY>"
GOOGLE_SEARCH_ENGINE_ID = "<YOUR_GOOGLE_SEARCH_ENGINE_ID>"
CLAUDE = "<YOUR_CLAUDE_API_KEY>"
NETLIFY = "<YOUR_NETLIFY_API_KEY>"
OPENAI = "<YOUR_OPENAI_API_KEY>"
GROQ = "<YOUR_GROQ_API_KEY>"
GEMINI = "<YOUR_GEMINI_API_KEY>"

[API_ENDPOINTS]
BING = "https://api.bing.microsoft.com/v7.0/search"
GOOGLE_SEARCH = "https://www.googleapis.com/customsearch/v1"
OLLAMA = "http://127.0.0.1:11434"

[LOGGING]
Expand Down
2 changes: 1 addition & 1 deletion setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ pip3 install -r requirements.txt
playwright install
python3 -m playwright install-deps
cd ui/
npm install
bun install
30 changes: 23 additions & 7 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
from .decision import Decision

from src.logger import Logger
from src.config import Config
from src.project import ProjectManager
from src.state import AgentState

from src.bert.sentence import SentenceBert
from src.memory import KnowledgeBase
from src.browser.search import BingSearch
from src.browser.search import BingSearch,DuckDuckGoSearch,GoogleSearch
from src.browser import Browser
from src.browser import start_interaction
from src.filesystem import ReadCode
Expand Down Expand Up @@ -59,11 +60,26 @@ def __init__(self, base_model: str):

self.tokenizer = tiktoken.get_encoding("cl100k_base")

def search_queries(self, queries: list, project_name: str) -> dict:
def search_queries(self, queries: list, project_name: str, requested_web_search: str) -> dict:
results = {}

knowledge_base = KnowledgeBase()
bing_search = BingSearch()

web_search = None
web_search_type = Config().get_web_search()
if requested_web_search:
web_search_type = requested_web_search

if web_search_type == "bing":
web_search = BingSearch()
elif web_search_type == "google":
web_search = GoogleSearch()
elif web_search_type == "ddgs":
web_search = DuckDuckGoSearch()
else:
web_search = DuckDuckGoSearch()

self.logger.info(web_search_type)
browser = Browser()

for query in queries:
Expand All @@ -80,8 +96,8 @@ def search_queries(self, queries: list, project_name: str) -> dict:
"""
Search for the query and get the first link
"""
bing_search.search(query)
link = bing_search.get_first_link()
web_search.search(query)
link = web_search.get_first_link()

"""
Browse to the link and take a screenshot, then extract the text
Expand Down Expand Up @@ -259,7 +275,7 @@ def subsequent_execute(self, prompt: str, project_name: str) -> str:
"""
Agentic flow of execution
"""
def execute(self, prompt: str, project_name_from_user: str = None) -> str:
def execute(self, prompt: str, project_name_from_user: str = None, web_search: str = None) -> str:
if project_name_from_user:
ProjectManager().add_message_from_user(project_name_from_user, prompt)

Expand Down Expand Up @@ -332,7 +348,7 @@ def execute(self, prompt: str, project_name_from_user: str = None) -> str:

AgentState().set_agent_active(project_name, True)

search_results = self.search_queries(queries, project_name)
search_results = self.search_queries(queries, project_name, web_search)

print(json.dumps(search_results, indent=4))
print("=====" * 10)
Expand Down
40 changes: 40 additions & 0 deletions src/browser/search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import requests
from src.config import Config
from duckduckgo_search import DDGS

class BingSearch:
def __init__(self):
Expand All @@ -22,4 +23,43 @@ def search(self, query):

def get_first_link(self):
return self.query_result["webPages"]["value"][0]["url"]

class GoogleSearch:
def __init__(self):
self.config = Config()
self.google_search_api_key = self.config.get_google_search_api_key()
self.google_search_engine_ID = self.config.get_google_search_engine_id()
self.google_search_api_endpoint = self.config.get_google_search_api_endpoint()
self.query_result = None

def search(self, query):
try:
params = {
'q': query,
'key': self.google_search_api_key,
'cx': self.google_search_engine_ID
}
response = requests.get(self.google_search_api_endpoint, params=params)
self.query_result = response.json()
except Exception as err:
return err

def get_first_link(self):
item = ""
if 'items' in self.query_result:
item = self.query_result['items'][0]['link']
return item

class DuckDuckGoSearch:
def __init__(self):
self.query_result = None

def search(self, query):
try:
self.query_result = DDGS().text(query, max_results=5)
return self.query_result
except Exception as err:
print(err)

def get_first_link(self):
return self.query_result[0]["href"]
37 changes: 36 additions & 1 deletion src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ def get_bing_api_key(self):

def get_bing_api_endpoint(self):
return environ.get("BING_API_ENDPOINT", self.config["API_ENDPOINTS"]["BING"])

def get_google_search_api_key(self):
return environ.get("GOOGLE_SEARCH_API_KEY", self.config["API_KEYS"]["GOOGLE_SEARCH"])

def get_google_search_engine_id(self):
return environ.get("GOOGLE_SEARCH_ENGINE_ID", self.config["API_KEYS"]["GOOGLE_SEARCH_ENGINE_ID"])

def get_google_search_api_endpoint(self):
return environ.get("GOOGLE_SEARCH_API_ENDPOINT", self.config["API_ENDPOINTS"]["GOOGLE_SEARCH"])

def get_ollama_api_endpoint(self):
return environ.get(
Expand All @@ -33,7 +42,10 @@ def get_claude_api_key(self):

def get_openai_api_key(self):
return environ.get("OPENAI_API_KEY", self.config["API_KEYS"]["OPENAI"])


def get_gemini_api_key(self):
return environ.get("GEMINI_API_KEY", self.config["API_KEYS"]["GEMINI"])

def get_netlify_api_key(self):
return environ.get("NETLIFY_API_KEY", self.config["API_KEYS"]["NETLIFY"])

Expand All @@ -58,6 +70,9 @@ def get_logs_dir(self):
def get_repos_dir(self):
return environ.get("REPOS_DIR", self.config["STORAGE"]["REPOS_DIR"])

def get_web_search(self):
return environ.get("WEB_SEARCH", self.config["STORAGE"]["WEB_SEARCH"])

def get_logging_rest_api(self):
return self.config["LOGGING"]["LOG_REST_API"] == "true"

Expand All @@ -72,6 +87,18 @@ def set_bing_api_endpoint(self, endpoint):
self.config["API_ENDPOINTS"]["BING"] = endpoint
self.save_config()

def set_google_search_api_key(self, key):
self.config["API_KEYS"]["GOOGLE_SEARCH"] = key
self.save_config()

def set_google_search_engine_id(self, key):
self.config["API_KEYS"]["GOOGLE_SEARCH_ENGINE_ID"] = key
self.save_config()

def set_google_search_api_endpoint(self, endpoint):
self.config["API_ENDPOINTS"]["GOOGLE_SEARCH"] = endpoint
self.save_config()

def set_ollama_api_endpoint(self, endpoint):
self.config["API_ENDPOINTS"]["OLLAMA"] = endpoint
self.save_config()
Expand All @@ -84,6 +111,10 @@ def set_openai_api_key(self, key):
self.config["API_KEYS"]["OPENAI"] = key
self.save_config()

def set_openai_api_key(self, key):
self.config["API_KEYS"]["GEMINI"] = key
self.save_config()

def set_netlify_api_key(self, key):
self.config["API_KEYS"]["NETLIFY"] = key
self.save_config()
Expand Down Expand Up @@ -120,6 +151,10 @@ def set_logging_prompts(self, value):
self.config["LOGGING"]["LOG_PROMPTS"] = "true" if value else "false"
self.save_config()

def set_web_search(self, value):
self.config["STORAGE"]["WEB_SEARCH"] = value
self.save_config()

def save_config(self):
with open("config.toml", "w") as f:
toml.dump(self.config, f)
20 changes: 19 additions & 1 deletion src/init.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
import os
import sys

from src.config import Config
from src.logger import Logger

def init_cmd():
config = Config()
logger = Logger()
command_line_args = sys.argv[1:]
if '--websearch' in command_line_args:
index = command_line_args.index('--websearch')
websearch_value = command_line_args[index + 1]
if(websearch_value == 'bing' or websearch_value == 'google' or websearch_value == 'ddgs'):
config.set_web_search(websearch_value)
else:
return logger.error(f"Invalid websearch value parameter: {websearch_value}")
else:
logger.info("No --websearch argument provided. Using default duckduckgo search.")

def init_devika():
config = Config()
logger = Logger()
Expand All @@ -20,7 +35,10 @@ def init_devika():
os.makedirs(pdfs_dir, exist_ok=True)
os.makedirs(projects_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)


init_cmd()
logger.info(f"Using {config.get_web_search()} as default if not specified in the request.")

from src.bert.sentence import SentenceBert

logger.info("Loading sentence-transformer BERT models...")
Expand Down
14 changes: 14 additions & 0 deletions src/llm/gemini_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import google.generativeai as genai

from src.config import Config

class Gemini:
def __init__(self):
config = Config()
api_key = config.get_gemini_api_key()
genai.configure(api_key=api_key)

def inference(self, model_id: str, prompt: str) -> str:
model = genai.GenerativeModel(model_id)
response = model.generate_content(prompt)
return response.text
5 changes: 5 additions & 0 deletions src/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .ollama_client import Ollama
from .claude_client import Claude
from .openai_client import OpenAI
from .gemini_client import Gemini
from .groq_client import Groq

from src.state import AgentState
Expand All @@ -22,6 +23,8 @@ class Model(Enum):
CLAUDE_3_HAIKU = ("Claude 3 Haiku", "claude-3-haiku-20240307")
GPT_4_TURBO = ("GPT-4 Turbo", "gpt-4-0125-preview")
GPT_3_5 = ("GPT-3.5", "gpt-3.5-turbo-0125")
GEMINI_1_0_PRO = ("Gemini 1.0 Pro", "gemini-1.0-pro")
GEMINI_1_5_PRO = ("Gemini 1.5 Pro", "gemini-1.5-pro")
OLLAMA_MODELS = [
(
model["name"].split(":")[0],
Expand Down Expand Up @@ -74,6 +77,8 @@ def inference(
response = OpenAI().inference(self.model_id, prompt).strip()
elif "GROQ" in str(model):
response = Groq().inference(self.model_id, prompt).strip()
elif "GEMINI" in str(model):
response = Gemini().inference(self.model_id, prompt).strip()
else:
raise ValueError(f"Model {model} not supported")

Expand Down
2 changes: 1 addition & 1 deletion src/llm/ollama_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def list_models():

def inference(self, model_id: str, prompt: str) -> str:
try:
response = ollama.generate(model=model_id, prompt=prompt.strip())
response = client.generate(model=model_id, prompt=prompt.strip())
return response['response']
except Exception as e:
logger.error(f"Error during model inference: {e}")
Expand Down

0 comments on commit 65c5bfd

Please sign in to comment.