Commit
Merge branch 'main' of https://github.com/OpenBMB/XAgent
luyaxi committed Oct 20, 2023
2 parents c9d5b5c + 2ea7b77 commit ff52c5f
Showing 13 changed files with 165 additions and 99 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -28,5 +28,6 @@ local_workspace/**
!local_workspace/readme.md
sync.py
assets/private.yml
assets/garbage

XAgentWeb/src/api/backend.ts
2 changes: 2 additions & 0 deletions README.md
@@ -91,6 +91,8 @@ You can use argument `--upload_files` to select files you want to submit to XAge
The local workspace for your XAgent is in `local_workspace`, where you can find all the files generated by XAgent throughout the running process. Besides, in `running_records` you can find all the intermediate-step information, e.g., task statuses, the LLM's input-output pairs, and the tools used.
After execution, the full `workspace` in `ToolServerNode` will be copied to `running_records` for your convenience.

If you want to load from an existing record, set `record_dir` in the config (it defaults to `Null`). Every run is saved as a record automatically; the `api key` and other sensitive items are removed, so you can share your run with other people by sharing its record.

- Run XAgent with GUI
```bash
cd XAgentServer
```
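
The record layout referenced above is written by `XAgent/running_recorder.py` (see its diff below). As a minimal sketch, assuming a hypothetical run directory and helper name, this is how you could inspect a saved record before replaying or sharing it:

```python
# Sketch only: `summarize_record` is not part of XAgent. The file and
# directory names (query.json, config.yml, LLM_inout_pair/, tool_server_pair/)
# match what running_recorder.py writes in this commit.
import os

def summarize_record(record_dir: str) -> None:
    """Print what a saved run contains before sharing or replaying it."""
    for name in ("query.json", "config.yml"):
        present = os.path.exists(os.path.join(record_dir, name))
        print(f"{name}: {'present' if present else 'missing'}")
    for sub in ("LLM_inout_pair", "tool_server_pair"):
        path = os.path.join(record_dir, sub)
        count = len(os.listdir(path)) if os.path.isdir(path) else 0
        print(f"{sub}/: {count} cached entries")

summarize_record("./running_records/my_run")  # hypothetical record directory
```
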
3 changes: 2 additions & 1 deletion ToolServer/ToolServerNode/core/envs/filesystem.py
@@ -163,7 +163,8 @@ def write_to_file(self, filepath:str,content:str,truncating:bool = False,line_nu
if not filepath.startswith(self.work_directory):
filepath = filepath.strip('/')
full_path = os.path.join(self.work_directory, filepath)

else:
full_path = filepath
if not self._is_path_within_workspace(full_path) or self._check_ignorement(full_path):
raise ValueError(f"File {filepath} is not within workspace.")

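
The hunk above fixes an unbound-variable bug: when `filepath` already started with the work directory, `full_path` was never assigned before the workspace check ran. A standalone sketch of the corrected branching, with `WORK_DIR` and an inline containment check standing in for `FileSystemEnv`'s attributes and helpers:

```python
# Sketch only: simplified stand-in for ToolServerNode's write_to_file path logic.
import os

WORK_DIR = "/app/workspace"

def resolve(filepath: str) -> str:
    if not filepath.startswith(WORK_DIR):
        # Relative path: anchor it under the workspace.
        full_path = os.path.join(WORK_DIR, filepath.strip("/"))
    else:
        # Absolute path already inside the workspace: use it as-is.
        # Before this commit this branch was missing, so full_path was
        # unbound here and the workspace check below raised a NameError.
        full_path = filepath
    if not os.path.normpath(full_path).startswith(WORK_DIR):
        raise ValueError(f"File {filepath} is not within workspace.")
    return full_path

print(resolve("notes/todo.md"))         # /app/workspace/notes/todo.md
print(resolve("/app/workspace/a.txt"))  # /app/workspace/a.txt
```
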
12 changes: 2 additions & 10 deletions XAgent/agent/utils.py
@@ -36,16 +36,8 @@ def _chat_completion_request(messages, functions=None,function_call=None, model=

    # Yujia: maybe temperature == 0 is more stable? Not rigorously tested.
# json_data.update({"temperature": 0.1})
response = recorder.query_llm_inout(restrict_cache_query=restrict_cache_query,
messages=messages,
functions=functions,
function_call=function_call,
model = model,
stop = stop,
other_args = kwargs)

if response is None:
response = openai_chatcompletion_request(**json_data)
logger.debug("from _chat_completion_request")
response = openai_chatcompletion_request(**json_data)

return response

90 changes: 56 additions & 34 deletions XAgent/ai_functions/request/openai.py
@@ -17,6 +17,8 @@ class FunctionCallSchemaError(Exception):
pass




def dynamic_json_fixs(args,function_schema,messages:list=[],error_message:str=None):
logger.typewriter_log(f'Schema Validation for Function call {function_schema["name"]} failed, trying to fix it...',Fore.YELLOW)
repair_req = deepcopy(CONFIG.default_completion_kwargs)
@@ -70,51 +72,71 @@ def load_args_with_schema_validation(function_schema:dict,args:str,messages:list
return arguments


LLM_query_count = 0 # global query counter; a mutex lock would be needed for concurrent use

@retry(retry=retry_if_not_exception_type((AuthenticationError, PermissionError, InvalidRequestError)),stop=stop_after_attempt(CONFIG.max_retry_times+6),wait=wait_chain(*[wait_none() for _ in range(6)]+[wait_exponential(min=113, max=293)]),reraise=True)
def openai_chatcompletion_request(*,function_call_check=True,**kwargs):
@retry(retry=retry_if_not_exception_type((AuthenticationError, PermissionError, InvalidRequestError,AssertionError)),stop=stop_after_attempt(CONFIG.max_retry_times+6),wait=wait_chain(*[wait_none() for _ in range(6)]+[wait_exponential(min=113, max=293)]),reraise=True)
def openai_chatcompletion_request(*,function_call_check=True,record_query_response=None,**kwargs):
global LLM_query_count
regist_kwargs = deepcopy(kwargs)
query_kwargs = deepcopy(kwargs)
model_name = get_openai_model_name(kwargs.pop('model', 'gpt-3.5-turbo-16k'))
print("using " + model_name)
logger.debug("openai_chatcompletion_request: using " + model_name)

chatcompletion_kwargs = get_apiconfig_by_model(model_name)
chatcompletion_kwargs.update(kwargs)
chatcompletion_kwargs.pop('schema_error_retry',None)

try:
response = openai.ChatCompletion.create(**chatcompletion_kwargs)
if response['choices'][0]['finish_reason'] == 'length':
raise InvalidRequestError('maximum context length exceeded',None)
except InvalidRequestError as e:
if 'maximum context length' in e._message:
if model_name == 'gpt-4':
if 'gpt-4-32k' in CONFIG.openai_keys:
model_name = 'gpt-4-32k'
else:
    llm_query_id = deepcopy(LLM_query_count) # snapshot of the global counter; a mutex lock would be needed for concurrent use
LLM_query_count += 1
record_query_response = recorder.query_llm_inout(llm_query_id = llm_query_id,
messages=query_kwargs.pop("messages",None),
functions=query_kwargs.pop("functions",None),
function_call=query_kwargs.pop("function_call",None),
model = query_kwargs.pop("model",None),
stop = query_kwargs.pop("stop",None),
other_args = query_kwargs)
if record_query_response != None:
response = record_query_response
else:
try:
response = openai.ChatCompletion.create(**chatcompletion_kwargs)
response = json5.loads(str(response))
if response['choices'][0]['finish_reason'] == 'length':
raise InvalidRequestError('maximum context length exceeded',None)
except InvalidRequestError as e:
logger.info(e)
if 'maximum context length' in e._message:
if model_name == 'gpt-4':
if 'gpt-4-32k' in CONFIG.openai_keys:
model_name = 'gpt-4-32k'
else:
model_name = 'gpt-3.5-turbo-16k'
elif model_name == 'gpt-3.5-turbo':
model_name = 'gpt-3.5-turbo-16k'
elif model_name == 'gpt-3.5-turbo':
model_name = 'gpt-3.5-turbo-16k'
else:
raise e
print("max context length reached, retrying with " + model_name)
chatcompletion_kwargs = get_apiconfig_by_model(model_name)
chatcompletion_kwargs.update(kwargs)
chatcompletion_kwargs.pop('schema_error_retry',None)

response = openai.ChatCompletion.create(**chatcompletion_kwargs)
response = json5.loads(str(response))
else:
raise e
print("max context length reached, retrying with " + model_name)
chatcompletion_kwargs = get_apiconfig_by_model(model_name)
chatcompletion_kwargs.update(kwargs)
chatcompletion_kwargs.pop('schema_error_retry',None)

response = openai.ChatCompletion.create(**chatcompletion_kwargs)
else:
raise e

# register the request and response
_kwargs = deepcopy(kwargs)
recorder.regist_llm_inout(messages=_kwargs.pop('messages',None),
functions=_kwargs.pop('functions',None),
function_call=_kwargs.pop('function_call',None),
model = _kwargs.get('model',None),
stop = _kwargs.get('stop',None),
other_args = _kwargs,
output_data = json5.loads(str(response)))
# import pdb; pdb.set_trace()
recorder.regist_llm_inout(llm_query_id = llm_query_id,
messages=regist_kwargs.pop('messages',None),
functions=regist_kwargs.pop('functions',None),
function_call=regist_kwargs.pop('function_call',None),
model = regist_kwargs.pop('model',None),
stop = regist_kwargs.pop('stop',None),
other_args = regist_kwargs,
output_data = response)
if function_call_check:
if 'function_call' not in response['choices'][0]['message']:
logger.info("FunctionCallSchemaError")
raise FunctionCallSchemaError(f"No function call found in the response: {response['choices'][0]['message']} ")
# verify the schema of the function call if exists
function_schema = None
@@ -133,10 +155,10 @@ def openai_chatcompletion_request(*,function_call_check=True,**kwargs):
else:
kwargs['messages'][-1]['content'] = function_schema_error

logger.info("not found in the provided functions")
raise FunctionCallSchemaError(f"Function {response['choices'][0]['message']['function_call']['name']} not found in the provided functions: {list(map(lambda x:x['name'],kwargs['functions']))}")

arguments,response = load_args_with_schema_validation(function_schema,response['choices'][0]['message']['function_call']['arguments'],kwargs['messages'],return_response=True,response=response)


return response
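
Taken together, the changes in this file swap content-based cache lookup for positional replay: every request first takes a sequential id from the global `LLM_query_count`, asks the recorder for the cached response at that position, and only hits the OpenAI API when no record is loaded. A distilled sketch of that scheme, with illustrative (non-XAgent) names:

```python
# Sketch only: positional replay of recorded LLM calls.
from typing import Any, List

class ReplayCache:
    def __init__(self, records: List[dict]):
        self.records = records  # ordered LLM in/out pairs loaded from disk

    def query(self, query_id: int, input_data: dict) -> Any:
        if query_id >= len(self.records):
            raise SystemExit("reached the end of the record")  # mirrors exit() above
        cache = self.records[query_id]
        # Positional replay still verifies the input matches what was recorded,
        # so a run that diverges from the record fails fast instead of silently
        # consuming the wrong cached response.
        assert input_data == cache["input"], f"query {query_id} diverged from record"
        return cache["output"]

LLM_query_count = 0  # global counter, as in the hunk above; assumes a single thread

def next_query_id() -> int:
    global LLM_query_count
    query_id = LLM_query_count
    LLM_query_count += 1
    return query_id
```

Positional ids make replay deterministic, at the cost that a replayed run must issue exactly the same queries in the same order as the recorded one.
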
9 changes: 7 additions & 2 deletions XAgent/config.py
@@ -18,8 +18,13 @@ def __delattr__(self, key):
del self[key]
else:
raise AttributeError(f"'DotDict' object has no attribute '{key}'")
def to_dict(self):
return self
def to_dict(self, safe=False):
if safe:
right_value = deepcopy(self)
right_value.pop("openai_keys","")
return right_value
else:
return self

def reload(self,config_file='config.yml'):
self.__init__(**yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader))
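
`to_dict(safe=True)` is what makes records shareable: the config is deep-copied and secret entries are popped before anything is serialized into the record directory. The same redaction pattern in isolation (key names other than `openai_keys` are examples):

```python
# Sketch only: redact secrets from a config before writing a shareable record.
from copy import deepcopy

def to_safe_dict(config: dict, secret_keys=("openai_keys",)) -> dict:
    safe = deepcopy(config)   # never mutate the live config
    for key in secret_keys:
        safe.pop(key, None)   # drop secrets if present
    return safe

cfg = {"model": "gpt-4", "openai_keys": {"gpt-4": ["sk-..."]}}
print(to_safe_dict(cfg))      # {'model': 'gpt-4'}
```
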
113 changes: 68 additions & 45 deletions XAgent/running_recorder.py
@@ -1,12 +1,13 @@
import os
import time
import json
import yaml
import uuid
import logging
from colorama import Fore, Style
from XAgent.loggers.logs import logger
from XAgent.workflow.base_query import AutoGPTQuery
from XAgent.config import XAgentConfig

from XAgent.config import XAgentConfig, CONFIG

def dump_common_things(object):
if type(object) in [str, int, float, bool]:
@@ -34,7 +35,6 @@ def __init__(self, record_root_dir="./running_records/"):

self.newly_start = True

self.llm_interface_id = 0
self.tool_server_interface_id = 0

self.tool_call_id = 0
@@ -65,8 +65,8 @@ def regist_plan_modify(self, refine_function_name, refine_function_input, refine

self.plan_refine_id += 1

def regist_llm_inout(self, messages, functions, function_call, model, stop, other_args, output_data):
with open(os.path.join(self.record_root_dir, "LLM_inout_pair", f"{self.llm_interface_id:05d}.json"),
def regist_llm_inout(self, llm_query_id, messages, functions, function_call, model, stop, other_args, output_data):
with open(os.path.join(self.record_root_dir, "LLM_inout_pair", f"{llm_query_id:05d}.json"),
"w") as writer:
llm_inout_record = {
"input": {
@@ -78,14 +78,14 @@ def regist_llm_inout(self, messages, functions, function_call, model, stop, othe
"other_args": dump_common_things(other_args),
},
"output": dump_common_things(output_data),
"llm_interface_id": self.llm_interface_id,
"llm_interface_id": llm_query_id,
}
json.dump(llm_inout_record, writer, indent=2, ensure_ascii=False)
self.llm_server_cache.append(llm_inout_record)
logger.typewriter_log("LLM inout registed:",Fore.RED, f"query-id={llm_query_id}",level=logging.DEBUG)

self.llm_interface_id += 1

def query_llm_inout(self, restrict_cache_query, messages, functions, function_call, model, stop, other_args):
def query_llm_inout(self, llm_query_id, messages, functions, function_call, model, stop, other_args):
if self.newly_start:
return None
input_data = {
@@ -96,17 +96,20 @@ def query_llm_inout(self, llm_query_id, messages, functions, function_ca
"stop": dump_common_things(stop),
"other_args": dump_common_things(other_args),
}
for cache in self.llm_server_cache:
if input_data == cache["input"]:
if restrict_cache_query and self.llm_interface_id != cache["llm_interface_id"]:
continue
logger.typewriter_log(
"get a llm_server response from Record",
Fore.RED,
)
# import pdb; pdb.set_trace()
return cache["output"]
return None
if llm_query_id >= len(self.llm_server_cache):
logger.typewriter_log("Reach the max length of record: exit!!")
exit()
cache = self.llm_server_cache[llm_query_id]
# import pdb; pdb.set_trace()
if input_data == cache["input"]:
logger.typewriter_log(
"get a llm_server response from Record",
Fore.BLUE,
f"query-id={llm_query_id}"
)
return cache["output"]
assert False, f"{llm_query_id} didn't find output"


def regist_tool_call(self, tool_name, tool_input, tool_output, tool_status_code, thought_data=None):
os.makedirs(os.path.join(self.record_root_dir, self.now_subtask_id), exist_ok=True)
@@ -139,70 +142,90 @@ def regist_tool_server(self, url, payload, output):
def query_tool_server_cache(self, url, payload):
if self.newly_start:
return None
for cache in self.tool_server_cache:
# import pdb; pdb.set_trace()
if cache["url"] == url.split("/")[-1] and cache["payload"] == dump_common_things(payload):
logger.typewriter_log(
"get a tool_server response from Record",
Fore.RED,
cache["url"],
)
return cache["tool_output"]

return None
assert self.tool_server_interface_id < len(self.tool_server_cache), "Running Exists Record Saved Region"
cache = self.tool_server_cache[self.tool_server_interface_id]
# import pdb; pdb.set_trace()
if cache["url"] == url.split("/")[-1] and cache["payload"] == dump_common_things(payload):
logger.typewriter_log(
"get a tool_server response from Record",
Fore.BLUE,
cache["url"],
)
return cache["tool_output"]

assert False

def regist_query(self, query):
with open(os.path.join(self.record_root_dir, f"query.json"), "w") as writer:
json.dump(query.to_json(), writer, indent=2, ensure_ascii=False)


def get_query(self):
logger.typewriter_log(
"load a query from Record",
Fore.RED,
Fore.BLUE,
)
return self.query

def regist_config(self, config: XAgentConfig):
with open(os.path.join(self.record_root_dir, f"config.json"), "w") as writer:
json.dump(config.to_dict(), writer, indent=2, ensure_ascii=False)
with open(os.path.join(self.record_root_dir, f"config.yml"), "w") as writer:
writer.write(yaml.safe_dump(dict(config.to_dict(safe=True)), allow_unicode=True))

def get_config(self):
logger.typewriter_log(
"load a config from Record",
Fore.RED,
Fore.BLUE,
)
return self.config

def regist_father_info(self, record_dir):
with open(os.path.join(self.record_root_dir, f"This-Is-A-Reload-Run.yml"), "w") as writer:
writer.write(yaml.safe_dump({
"load_record_dir": record_dir,
}, allow_unicode=True))


def load_from_disk(self, record_dir):
logger.typewriter_log(
"load from a disk record",
Fore.RED,
"load from a disk record, overwrite all the existing config-info",
Fore.BLUE,
record_dir,
)

self.regist_father_info(record_dir)
self.newly_start = False

for dir_name in os.listdir(record_dir):
if dir_name == "query.json":
with open(os.path.join(record_dir, dir_name), "r") as reader:
self.query_json = json.load(reader)
self.query = AutoGPTQuery.from_json(self.query_json)
elif dir_name == "config.json":
with open(os.path.join(record_dir, dir_name), "r") as reader:
self.config_json = json.load(reader)
self.config = XAgentConfig()
self.config.merge_from_dict(self.config_json)
elif dir_name == "config.yml":
CONFIG.reload(os.path.join(record_dir, dir_name))
elif dir_name == "LLM_inout_pair":
inout_count = len(os.listdir(os.path.join(record_dir, dir_name)))
self.llm_server_cache = [None]*inout_count
for file_name in os.listdir(os.path.join(record_dir, dir_name)):
inout_id = int(file_name.split(".")[0])
with open(os.path.join(record_dir, dir_name, file_name), "r") as reader:
llm_pair = json.load(reader)
self.llm_server_cache.append(llm_pair)
self.llm_server_cache[inout_id] = llm_pair
logger.typewriter_log(
f"Record contain {inout_count} LLM inout",
Fore.BLUE,
)
elif dir_name == "tool_server_pair":
inout_count = len(os.listdir(os.path.join(record_dir, dir_name)))
self.tool_server_cache = [None]*inout_count
for file_name in os.listdir(os.path.join(record_dir, dir_name)):
inout_id = int(file_name.split(".")[0])
with open(os.path.join(record_dir, dir_name, file_name), "r") as reader:
tool_pair = json.load(reader)
self.tool_server_cache.append(tool_pair)
else:
self.tool_server_cache[inout_id] = tool_pair
logger.typewriter_log(
f"Record contain {len(os.listdir(os.path.join(record_dir, dir_name)))} Tool call",
Fore.BLUE,
)
elif os.path.isdir(os.path.join(record_dir, dir_name)):
for file_name in os.listdir(os.path.join(record_dir, dir_name)):
if file_name.startswith("plan_refine"):
with open(os.path.join(record_dir, dir_name, file_name)) as reader:
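
One subtle fix in `load_from_disk` above: cached pairs are now placed into a preallocated list at the index parsed from each file name, instead of being appended in `os.listdir()` order, which is not guaranteed to be sorted; with appended entries, positional replay could read pairs out of order. The pattern in isolation:

```python
# Sketch only: index-addressed loading of record files like "00007.json".
import json
import os

def load_pairs(pair_dir: str) -> list:
    file_names = os.listdir(pair_dir)
    cache = [None] * len(file_names)        # preallocate so ids map to slots
    for file_name in file_names:
        idx = int(file_name.split(".")[0])  # "00007.json" -> 7
        with open(os.path.join(pair_dir, file_name)) as reader:
            cache[idx] = json.load(reader)
    return cache
```
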