Skip to content

Commit

Permalink
Fix agent import
Browse files Browse the repository at this point in the history
  • Loading branch information
Josh-XT committed Dec 12, 2023
1 parent 084bc6a commit cb85f38
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 43 deletions.
9 changes: 6 additions & 3 deletions agixt/db/Agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,15 @@ def import_agent_config(agent_name, user="USER"):
.first()
)

if not agent:
print(f"Agent '{agent_name}' does not exist in the database.")
if agent:
print(f"Agent '{agent_name}' already exists in the database.")
return

# Get the provider ID based on the provider name in the config
provider_name = config["settings"]["provider"]
try:
provider_name = config["settings"]["provider"]
except:
provider_name = "gpt4free"
provider = session.query(ProviderModel).filter_by(name=provider_name).first()

if not provider:
Expand Down
45 changes: 5 additions & 40 deletions agixt/db/imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import yaml
import json
import logging
from DBConnection import (
Expand All @@ -23,50 +22,16 @@


def import_agents(user="USER"):
session = get_session()
user_data = session.query(User).filter(User.email == user).first()
user_id = user_data.id
agent_folder = "agents"
logging.info("Importing agents...")
agents = [
f.name
for f in os.scandir(agent_folder)
for f in os.scandir("agents")
if f.is_dir() and not f.name.startswith("__")
]
existing_agents = session.query(Agent).filter(Agent.user_id == user_id).all()
existing_agent_names = [agent.name for agent in existing_agents]

for agent_name in agents:
agent = session.query(Agent).filter_by(name=agent_name, user_id=user_id).first()
if agent:
print(f"Updating agent: {agent_name}")
else:
# Get the agent config from agents/agent_name/config.json
agent_config_file = os.path.join(agent_folder, agent_name, "config.json")
if not os.path.exists(agent_config_file):
print(f"Agent '{agent_name}' config not found.")
continue
with open(agent_config_file, "r") as f:
agent_config = json.load(f)
agent_settings = agent_config.get("settings", {"provider": "gpt4free"})
provider_name = agent_settings["provider"]
provider = (
session.query(Provider).filter_by(name=provider_name).one_or_none()
)
if not provider:
print(f"Provider '{provider_name}' not found.")
continue
agent = Agent(
name=agent_name,
user_id=user_id,
provider_id=provider.id,
)
session.add(agent)
session.flush() # Save the agent object to generate an ID
existing_agent_names.append(agent_name)
print(f"Adding agent: {agent_name}")
session.commit()
import_agent_config(agent_name)
session.commit()
logging.info(f"Importing agent: {agent_name}")
import_agent_config(agent_name=agent_name, user=user)
logging.info("Agent import complete.")


def import_extensions():
Expand Down

0 comments on commit cb85f38

Please sign in to comment.