-
Notifications
You must be signed in to change notification settings - Fork 1k
/
ast_eval_tf.py
171 lines (153 loc) · 6.05 KB
/
ast_eval_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from tree_sitter import Language, Parser
# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
node_stack = []
sub_tree_sexp_list = []
depth = 1
text = root_node.text
node_stack.append([root_node, depth])
while len(node_stack) != 0:
cur_node, cur_depth = node_stack.pop()
if cur_node.child_count > 0:
sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, cur_node.children[0].text])
else:
sub_tree_sexp_list.append([cur_node.sexp(), cur_depth, cur_node, None])
for child_node in cur_node.children:
if len(child_node.children) != 0:
depth = cur_depth + 1
node_stack.append([child_node, depth])
return sub_tree_sexp_list
# Parse a source string into a tree-sitter AST and return its root node.
def ast_parse(candidate, lang='python'):
    """Parse ``candidate`` with the pre-built tree-sitter grammar bundle.

    ``lang`` selects the grammar inside ``my-languages.so`` (default:
    python).  Returns the root node of the parsed AST.
    """
    grammar = Language('codebleu/parser/my-languages.so', lang)
    ts_parser = Parser()
    ts_parser.set_language(grammar)
    return ts_parser.parse(bytes(candidate, 'utf8')).root_node
# Get all the arguments in the ast tree
def get_args(node):
if node.child_count == 0:
return []
args_list = []
for child in node.children[0].children[0].children[1].children:
if 'model=' in child.text.decode() or 'model =' in child.text.decode():
args_list.append(child.children[2].text)
elif child.text.decode() != "(" and child.text.decode() != ")" and child.text.decode() != ",":
args_list.append(child.text)
return args_list
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list):
    """Return the index of the first reference AST matching the candidate.

    A reference matches when (a) some candidate subtree's first-child text
    equals the reference's API name and (b) every reference argument string
    (sans surrounding single quotes) appears in that subtree's source text.

    Args:
        candidate_subtree_list: output of ``get_all_sub_trees`` —
            ``[[sexp, depth, node, first_child_text], ...]``.
        base_tree_list: reference root nodes produced by ``ast_parse``.

    Returns:
        Index into ``base_tree_list`` of the match, or -1 when none matches.
    """
    for idx, base_tree in enumerate(base_tree_list):
        # Skip malformed reference entries with no API-name node.
        if base_tree.children[0].children[0].child_count == 0:
            continue
        api_name = base_tree.children[0].children[0].children[0].text
        # Find the candidate subtree rooted at the same API name.
        matched_node = None
        for candidate_entry in candidate_subtree_list:
            if candidate_entry[3] == api_name:
                matched_node = candidate_entry[2]
                break
        if matched_node is None:
            # BUGFIX: the original loop fell through with the *last* candidate
            # still bound (or an unbound name on an empty list) and compared
            # arguments against an unrelated subtree; an unmatched API name
            # must simply mean "this reference does not match".
            continue
        args_list = get_args(base_tree)
        # References without arguments cannot be validated — skip them.
        if len(args_list) == 0:
            continue
        # Every reference argument must occur in the matched subtree's text.
        if all(arg.decode().lstrip("'").rstrip("'") in matched_node.text.decode()
               for arg in args_list):
            return idx
    return -1
# Load every dataset needed for evaluation from the paths in `args`.
def parse_dataset(args):
    """Read the three JSONL inputs and pre-parse the reference API calls.

    Reads ``args.api_dataset`` (API database), ``args.apibench`` (QA pairs,
    keeping only each line's ``"api_data"`` field) and ``args.llm_responses``
    (model outputs), then parses every reference ``api_call`` into an AST.

    Returns:
        (api_database, qa_pairs, llm_responses, ast_database)
    """
    def _read_jsonl(path):
        # One JSON document per line.
        with open(path, 'r') as fh:
            return [json.loads(row) for row in fh]

    api_database = _read_jsonl(args.api_dataset)
    qa_pairs = [entry["api_data"] for entry in _read_jsonl(args.apibench)]
    llm_responses = _read_jsonl(args.llm_responses)
    # Pre-parse every reference api_call so matching can reuse the trees.
    ast_database = [ast_parse(entry['api_call']) for entry in api_database]
    return api_database, qa_pairs, llm_responses, ast_database
def main(args):
    """Evaluate LLM responses against the reference API database.

    For each response, extract the api_call snippet from the generated
    text, parse it into an AST, and look it up in the reference AST
    database.  Prints the final functionality-accuracy and hallucination
    rates over all responses.
    """
    # Read datasets
    api_database, qa_pairs, llm_responses, ast_database = parse_dataset(args)
    # Guard against an empty response file (the original raised
    # ZeroDivisionError at the final accuracy computation).
    if not llm_responses:
        print('Error: no responses to evaluate')
        return
    # Check correctness
    total_correct = 0
    total_hallucination = 0
    for idx, response in enumerate(llm_responses):
        try:
            output = response['text']
        # BUGFIX: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; only a missing key or a
        # non-mapping response is a parse failure here.
        except (KeyError, TypeError):
            print('Error: cannot parse line ', idx)
            continue
        # Index the "api_call" domain
        parts = output.split("api_call")
        if len(parts) == 1:
            # No explicit "api_call" marker: treat the whole text as the call.
            api_call = parts[0]
        else:
            # Keep only the text between "api_call" and "api_provider".
            segment = parts[1].split("api_provider")[0]
            # The snippet usually looks like `: <call>)`; slice between the
            # colon and the last closing paren.  Offsets (start+2, end+1)
            # are preserved exactly from the original extraction logic.
            start = segment.index(":") if ":" in segment else 0
            end = segment.rindex(")") if ")" in segment else -2
            api_call = segment[start + 2:end + 1]
        # Parse the api_call into AST tree
        ast_tree = ast_parse(api_call)
        # Search for a subtree
        ast_subtree_list = get_all_sub_trees(ast_tree)
        # Check which ast tree is matching
        database_index = ast_check(ast_subtree_list, ast_database)
        # We cannot index this ast in our database
        if database_index == -1:
            total_hallucination += 1
            continue
        # We index our reference api_call; functionality is correct when the
        # matched API's domain agrees with the question's domain.
        ref_api_call = api_database[database_index]
        if ref_api_call['domain'] == qa_pairs[response['question_id'] - 1]['domain']:
            total_correct += 1
    print('Final Functionality accuracy: ', total_correct / len(llm_responses))
    print('Final hallucination: ', total_hallucination/len(llm_responses))
if __name__ == "__main__":
    # Command-line entry point: all three inputs are JSONL file paths.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--api_dataset", type=str, default=None, help="path to your api dataset")
    arg_parser.add_argument("--apibench", type=str, default=None, help="path to your apibench dataset including the question and answer pairs")
    arg_parser.add_argument("--llm_responses", type=str, default=None, help="path to the language model responses")
    main(arg_parser.parse_args())