Skip to content

Commit

Permalink
doc & fix: Update tool calling benchmark examples and fix bugs (#1419)
Browse files Browse the repository at this point in the history
  • Loading branch information
harryeqs authored Jan 8, 2025
1 parent 0f95ba9 commit 837096e
Show file tree
Hide file tree
Showing 5 changed files with 319 additions and 15 deletions.
5 changes: 5 additions & 0 deletions camel/benchmarks/apibank.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@

logger = logging.getLogger(__name__)

# Add the current working directory to sys.path so modules located there can be imported
current_folder = os.getcwd()
if current_folder not in sys.path:
sys.path.append(current_folder)


def process_messages(
chat_history: List[Dict[str, Any]],
Expand Down
12 changes: 8 additions & 4 deletions camel/benchmarks/apibench.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,28 @@

logger = logging.getLogger(__name__)


# Mapping of dataset names to file names
# The 'Oracle' retriever is used here, which means the full
# API documentation will be included in the prompt
dataset_mapping = {
"huggingface": {
"api": "huggingface_api.jsonl",
"eval": "huggingface_eval.json",
"train": "huggingface_train.json",
"questions": "questions_huggingface_0_shot.jsonl",
"questions": "questions_huggingface_oracle.jsonl",
},
"tensorflowhub": {
"api": "tensorflowhub_api.jsonl",
"eval": "tensorflow_eval.json",
"train": "tensorflow_train.json",
"questions": "questions_tensorflowhub_0_shot.jsonl",
"questions": "questions_tensorflowhub_oracle.jsonl",
},
"torchhub": {
"api": "torchhub_api.jsonl",
"eval": "torchhub_eval.json",
"train": "torchhub_train.json",
"questions": "questions_torchhub_0_shot.jsonl",
"questions": "questions_torchhub_oracle.jsonl",
},
}

Expand Down Expand Up @@ -173,7 +177,7 @@ def download(self):
)

repo = "ShishirPatil/gorilla"
subdir = "eval/eval-data/questions"
subdir = "/gorilla/eval/eval-data/questions"
data_dir = self.data_dir

download_github_subdirectory(repo, subdir, data_dir)
Expand Down
58 changes: 53 additions & 5 deletions examples/benchmarks/apibank.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,60 @@

from camel.agents import ChatAgent
from camel.benchmarks import APIBankBenchmark
from camel.benchmarks.apibank import Evaluator

# Set up the agent to be benchmarked
agent = ChatAgent()

benchmark = APIBankBenchmark(
data_dir="./APIBankDatasets", save_to="./APIBankResults.jsonl"
)
# Set up the APIBank Benchmark
# Please note that the data_dir is predefined
# for better import management of the tools
benchmark = APIBankBenchmark(save_to="APIBankResults.jsonl")

# Download the benchmark data
benchmark.download()
results = benchmark.run(agent, "level-1", api_test_enabled=True)
print(results)

# Set the subset to be benchmarked
level = 'level-1'

# NOTE: You might encounter the following error when
# running the benchmark in Windows:
# UnicodeDecodeError: 'charmap' codec can't decode byte 0x81
# in position 130908: character maps to <undefined>
# To solve this issue, you can navigate to the file
# api_bank/tool_manager.py, line 30, and change the encoding:
# with open(os.path.join(init_database_dir, file), 'r',
# encoding='utf-8') as f:


# Run the benchmark
result = benchmark.run(agent, level, api_test_enabled=True, subset=10)

# The following steps are only for demonstration purposes;
# they have been integrated into the run method of the benchmark.
# Get the first example of the test data
example_test = list(benchmark._data.items())[0] # type: ignore[assignment] # noqa: RUF015
evaluator = Evaluator(example_test)
api_description = evaluator.get_api_description('ToolSearcher')
print('\nAPI description for ToolSearcher:\n', api_description)
'''
===============================================================================
API description for ToolSearcher:
{"name": "ToolSearcher", "description": "Searches for relevant tools in
library based on the keywords.", "input_parameters": {"keywords": {"type":
"str", "description": "The keyword to search for."}},
"output_parameters":
{"best_matchs": {"type": "Union[List[dict], dict]",
"description": "The best match tool(s)."}}}
===============================================================================
'''

# Print the final results
print("Total:", result["total"])
print("Correct:", result["correct"])
'''
===============================================================================
Total: 24
Correct: 10
===============================================================================
'''
61 changes: 58 additions & 3 deletions examples/benchmarks/apibench.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,65 @@
from camel.agents import ChatAgent
from camel.benchmarks import APIBenchBenchmark

# Set up the agent to be benchmarked
agent = ChatAgent()

# Set up the APIBench Benchmark
benchmark = APIBenchBenchmark(
data_dir="./APIBenchDatasets", save_to="./APIBench.jsonl"
data_dir="APIBenchDatasets", save_to="APIBenchResults.jsonl"
)
result = benchmark.run(agent, 'torchhub')
print(result)

# Download the benchmark data
benchmark.download()

# Set the subset to be benchmarked
subset_name = 'torchhub'

# Run the benchmark
result = benchmark.run(agent, subset_name, subset=10)

# Please note that APIBench does not use 'real function call'
# but instead includes API documentation in the questions
# for the agent to reference.
# An example question including the API documentation is printed below.
print(
"\nExample question including API documentation:\n",
benchmark._data['questions'][0]['text'],
)
'''
===============================================================================
Example question including API documentation:
What is an API that can be used to classify sports activities in videos?\n
Use this API documentation for reference:
{"domain": "Video Classification", "framework": "PyTorch",
"functionality": "3D ResNet", "api_name": "slow_r50",
"api_call": "torch.hub.load(repo_or_dir='facebookresearch/pytorchvideo',
model='slow_r50', pretrained=True)", "api_arguments": {"pretrained": "True"},
"python_environment_requirements": ["torch", "json", "urllib",
"pytorchvideo",
"torchvision", "torchaudio", "torchtext", "torcharrow", "TorchData",
"TorchRec", "TorchServe", "PyTorch on XLA Devices"],
"example_code": ["import torch",
"model = torch.hub.load('facebookresearch/pytorchvideo',
'slow_r50', pretrained=True)",
"device = 'cpu'", "model = model.eval()", "model = model.to(device)"],
"performance": {"dataset": "Kinetics 400",
"accuracy": {"top_1": 74.58, "top_5": 91.63},
"Flops (G)": 54.52, "Params (M)": 32.45},
"description": "The 3D ResNet model is a Resnet-style video classification
network pretrained on the Kinetics 400 dataset. It is based on the
architecture from the paper 'SlowFast Networks for Video Recognition'
by Christoph Feichtenhofer et al."}}
===============================================================================
'''

print("Total:", result["total"])
print("Correct:", result["correct"])
print("Hallucination:", result["hallucination"])
'''
===============================================================================
Total: 10
Correct: 10
Hallucination: 0
===============================================================================
'''
198 changes: 195 additions & 3 deletions examples/benchmarks/nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,204 @@

from camel.agents import ChatAgent
from camel.benchmarks import NexusBenchmark
from camel.benchmarks.nexus import construct_tool_descriptions

# Set up the agent to be benchmarked
agent = ChatAgent()

# Set up the Nexusraven Function Calling Benchmark
benchmark = NexusBenchmark(
data_dir="./NexusDatasets", save_to="./NexusResults.jsonl"
data_dir="NexusDatasets", save_to="NexusResults.jsonl"
)

# Download the benchmark data
benchmark.download()
results = benchmark.run(agent, "OTX")
print(results)

# Set the task (sub-dataset) to be benchmarked
task = "OTX"

# Please note that the following step is only for demonstration purposes;
# it has been integrated into the run method of the benchmark.
# The tools fetched here are used to construct the prompt for the task,
# which will be passed to the agent for response.
tools = construct_tool_descriptions(task)
print('\nTool descriptions for the task:\n', tools)
'''
===============================================================================
"""
Function:
def getIndicatorForIPv6(apiKey: str, ip: str, section: str):
"""
Retrieves comprehensive information for a specific IPv6 address from the
AlienVault database.
This function allows you to obtain various types of data.
The 'general' section provides general information about the IP,
including geo data, and a list of other available sections.
'reputation' offers OTX data on malicious activity observed by
AlienVault Labs. 'geo' details more verbose geographic data such
as country code and coordinates. 'malware' reveals malware samples
connected to the IP,
and 'urlList' shows URLs associated with the IP. Lastly, 'passiveDns'
includes passive DNS information about hostnames/domains
pointing to this IP.
Args:
- apiKey: string, required, Your AlienVault API key
- ip: string, required, IPv6 address to query
- section: string, required, Specific data section to retrieve
(options: general, reputation, geo, malware, urlList, passiveDns)
"""
Function:
def getIndicatorForDomain(apiKey: str, domain: str, section: str):
"""
Retrieves a comprehensive overview for a given domain name from the
AlienVault database. This function provides various data types
about the domain. The 'general' section includes general information
about the domain, such as geo data, and lists of other available
sections. 'geo' provides detailed geographic data including country
code and coordinates. The 'malware' section indicates malware samples
associated with the domain. 'urlList' shows URLs linked to the domain,
'passiveDns' details passive DNS information about hostnames/domains
associated with the domain,
and 'whois' gives Whois records for the domain.
Args:
- apiKey: string, required, Your AlienVault API key
- domain: string, required, Domain address to query
- section: string, required, Specific data section to retrieve
(options: general, geo, malware, urlList, passiveDns, whois)
"""
Function:
def getIndicatorForHostname(apiKey: str, hostname: str, section: str):
"""
Retrieves detailed information for a specific hostname from the
AlienVault database. This function provides various data types about
the hostname. The 'general' section includes general information
about the IP, geo data, and lists of other available sections.
'geo' provides detailed geographic data including country code
and coordinates. The 'malware' section indicates malware samples
associated with the hostname. 'urlList' shows URLs linked to
the hostname, and 'passiveDns' details passive DNS information
about hostnames/domains associated with the hostname.
Args:
- apiKey: string, required, Your AlienVault API key
- hostname: string, required, Single hostname address to query
- section: string, required, Specific data section to retrieve
(options: general, geo, malware, urlList, passiveDns)
"""
Function:
def getIndicatorForFileHashes(apiKey: str, fileHash: str, section: str):
"""
Retrieves information related to a specific file hash from the
AlienVault database.
This function provides two types of data: 'general',
which includes general metadata about the file hash and a list of other
available sections for the hash; and 'analysis', which encompasses both
dynamic and static analysis of the file,
including Cuckoo analysis, exiftool, etc.
Args:
- apiKey: string, required, Your AlienVault API key
- fileHash: string, required, Single file hash to query
- section: string, required, Specific data section to retrieve
(options: general, analysis)
"""
Function:
def getIndicatorForUrl(apiKey: str, url: str, section: str):
"""
Retrieves information related to a specific URL from the AlienVault
database. This function offers two types of data: 'general',
which includes historical geographic information,
any pulses this indicator is on,
and a list of other available sections for this URL; and 'url_list',
which provides full results from AlienVault Labs URL analysis,
potentially including multiple entries.
Args:
- apiKey: string, required, Your AlienVault API key
- url: string, required, Single URL to query
- section: string, required, Specific data section to retrieve
(options: general, url_list)
"""
Function:
def getIndicatorForCVE(apiKey: str, cve: str, section: str):
"""
Retrieves information related to a specific CVE
(Common Vulnerability Enumeration)
from the AlienVault database. This function offers detailed data on CVEs.
The 'General' section includes MITRE CVE data, such as CPEs
(Common Platform Enumerations),
CWEs (Common Weakness Enumerations), and other relevant details.
It also provides information on any pulses this indicator is on,
and lists other sections currently available for this CVE.
Args:
- apiKey: string, required, Your AlienVault API key
- cve: string, required, Specific CVE identifier to query
(e.g., 'CVE-2014-0160')
- section: string, required, Specific data section to retrieve
('general' only)
"""
Function:
def getIndicatorForNIDS(apiKey: str, nids: str, section: str):
"""
Retrieves metadata information for a specific
Network Intrusion Detection System (NIDS)
indicator from the AlienVault database. This function is designed to
provide general metadata about NIDS indicators.
Args:
- apiKey: string, required, Your AlienVault API key
- nids: string, required, Specific NIDS indicator to query
(e.g., '2820184')
- section: string, required, Specific data section to retrieve
('general' only)
"""
Function:
def getIndicatorForCorrelationRules(apiKey: str, correlationRule: str,
section: str):
"""
Retrieves metadata information related to a specific Correlation Rule from
the AlienVault database. This function is designed to provide
general metadata about
Correlation Rules used in network security and event correlation.
Correlation Rules are crucial for identifying patterns and potential
security threats in network data.
Args:
- apiKey: string, required, Your AlienVault API key
- correlationRule: string, required, Specific Correlation Rule
identifier to query (e.g., '572f8c3c540c6f0161677877')
- section: string, required, Specific data section to retrieve
('general' only)
"""
===============================================================================
'''

# Run the benchmark
result = benchmark.run(agent, task, subset=10)
print("Total:", result["total"])
print("Correct:", result["correct"])
'''
===============================================================================
Total: 10
Correct: 9
===============================================================================
'''

0 comments on commit 837096e

Please sign in to comment.