fix pre-commit config and auto format all files
JianxinMa committed Apr 24, 2024
1 parent 9165f25 commit 01a6669
Showing 89 changed files with 713 additions and 1,603 deletions.
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
@@ -3,16 +3,17 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
args: ["--max-line-length=300"]
args: ["--max-line-length=300"] # TODO: Set to 120 and `pre-commit run --all-files`.
- repo: https://github.com/PyCQA/isort.git
rev: 5.11.5
hooks:
- id: isort
args: ["--line-length", "120"]
- repo: https://github.com/pre-commit/mirrors-yapf.git
rev: v0.32.0
hooks:
- id: yapf
args: ["--style={based_on_style: google, column_limit: 300, indent_width: 4}"]
args: ["--style", "{based_on_style: google, column_limit: 120}", "-i"]
- repo: https://github.com/pre-commit/pre-commit-hooks.git
rev: v4.3.0
hooks:
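Editor's note: the yapf hook now receives the style as separate arguments and lowers the column limit from 300 to 120, which is what produced the reflowed lines in the Python files below. A minimal sketch of the same style applied through yapf's Python API (yapf v0.32.0 as pinned above; the over-wrapped sample source is made up for illustration, and whether the lines are joined depends on yapf's formatting decisions):

    from yapf.yapflib.yapf_api import FormatCode

    # A made-up over-wrapped snippet, similar in shape to the pre-reformat code in this commit.
    source = (
        "connection_file = os.path.join(WORK_DIR,\n"
        "                               f'kernel_connection_file_{pid}.json')\n"
    )

    # The same style string the updated hook passes via --style.
    formatted, changed = FormatCode(source, style_config='{based_on_style: google, column_limit: 120}')
    print(changed)    # True when yapf rewrites the input
    print(formatted)  # the wrapped call should be joined onto one line, since it fits in 120 columns
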
15 changes: 5 additions & 10 deletions benchmark/code_interpreter.py
@@ -34,21 +34,17 @@
# plt.rcParams['axes.unicode_minus'] = False
# ````
def fix_matplotlib_cjk_font_issue():
local_ttf = os.path.join(
os.path.abspath(
os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)),
'fonts', 'ttf', 'simhei.ttf')
local_ttf = os.path.join(os.path.abspath(os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)), 'fonts',
'ttf', 'simhei.ttf')
if not os.path.exists(local_ttf):
logging.warning(
f'Missing font file `{local_ttf}` for matplotlib. It may cause some error when using matplotlib.'
)
f'Missing font file `{local_ttf}` for matplotlib. It may cause some error when using matplotlib.')


def start_kernel(pid):
fix_matplotlib_cjk_font_issue()

connection_file = os.path.join(WORK_DIR,
f'kernel_connection_file_{pid}.json')
connection_file = os.path.join(WORK_DIR, f'kernel_connection_file_{pid}.json')
launch_kernel_script = os.path.join(WORK_DIR, f'launch_kernel_{pid}.py')
for f in [connection_file, launch_kernel_script]:
if os.path.exists(f):
@@ -213,8 +209,7 @@ def _code_interpreter(code: str, timeout, clear=False):
finished = True
except Exception:
text = 'The code interpreter encountered an unexpected error.'
logging.warning(''.join(
traceback.format_exception(*sys.exc_info())))
logging.warning(''.join(traceback.format_exception(*sys.exc_info())))
finished = True
if text:
result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
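Editor's note on the `fix_matplotlib_cjk_font_issue` hunk above: the function only warns when `simhei.ttf` is missing from matplotlib's bundled font directory. A hedged sketch (not part of this commit) of how that font could be registered once it is present, using standard `font_manager` calls available in matplotlib 3.2 and later; the path construction mirrors the code in the diff:

    import os

    import matplotlib
    from matplotlib import font_manager

    # Same location code_interpreter.py checks: <mpl-data>/fonts/ttf/simhei.ttf,
    # derived from the directory that holds the active matplotlibrc.
    mpl_data_dir = os.path.dirname(matplotlib.matplotlib_fname())
    local_ttf = os.path.join(mpl_data_dir, 'fonts', 'ttf', 'simhei.ttf')

    if os.path.exists(local_ttf):
        # Register SimHei and prefer it for sans-serif text so CJK axis labels render;
        # keep the minus sign readable, as in the commented-out rcParams shown above.
        font_manager.fontManager.addfont(local_ttf)
        matplotlib.rcParams['font.sans-serif'] = ['SimHei']
        matplotlib.rcParams['axes.unicode_minus'] = False
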
13 changes: 3 additions & 10 deletions benchmark/config.py
@@ -15,12 +15,7 @@
'internlm': InternLMReActParser,
}

model_map = {
'qwen': Qwen,
'llama': LLM,
'internlm': LLM,
'qwen-vl-chat': QwenVL
}
model_map = {'qwen': Qwen, 'llama': LLM, 'internlm': LLM, 'qwen-vl-chat': QwenVL}

model_type_map = {
'qwen-72b-chat': 'qwen',
@@ -52,14 +47,12 @@


def get_react_prompt(model_name, query, lang, upload_fname_list):
react_prompt_cls = react_prompt_map.get(model_type_map[model_name],
QwenReAct)
react_prompt_cls = react_prompt_map.get(model_type_map[model_name], QwenReAct)
return react_prompt_cls(query, lang, upload_fname_list)


def get_react_parser(model_name):
react_parser_cls = react_parser_map.get(model_type_map[model_name],
ReActParser)
react_parser_cls = react_parser_map.get(model_type_map[model_name], ReActParser)
return react_parser_cls()


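Editor's note: a brief, hypothetical usage sketch of the two lookup helpers reflowed above; the model name and query are illustrative, and both helpers fall back to their defaults (`QwenReAct`, `ReActParser`) when a model type has no dedicated class:

    # Names come from benchmark/config.py as shown above; the arguments are illustrative.
    prompt_obj = get_react_prompt('qwen-14b-chat', 'Plot y = x ** 2 over [-5, 5]', 'en', [])
    planning_prompt = prompt_obj.build_prompt()  # ReAct-style prompt fed to the model

    parser = get_react_parser('qwen-14b-chat')   # class resolved via model_type_map / react_parser_map
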
101 changes: 29 additions & 72 deletions benchmark/inference_and_execute.py
@@ -7,8 +7,7 @@
import prettytable
import tqdm
from code_interpreter import code_interpreter
from config import (get_model, get_react_parser, get_react_prompt,
model_path_map)
from config import get_model, get_react_parser, get_react_prompt, model_path_map
from datasets import load_dataset
from metrics.code_execution import eval_code_execution_rate
from metrics.gsm8k import eval_gsm8k_acc, is_correct
@@ -45,11 +44,9 @@ def llm_with_plugin(args, query, item=None, exec_limit=3):
exec_count = 0

# Build ReAct prompt
upload_fname_list = item[
'input_file_path'] if item and 'input_file_path' in item else []
upload_fname_list = item['input_file_path'] if item and 'input_file_path' in item else []
lang = item['lang'] if item and 'lang' in item else 'en'
react_prompt_obj = get_react_prompt(args.model, query, lang,
upload_fname_list)
react_prompt_obj = get_react_prompt(args.model, query, lang, upload_fname_list)
planning_prompt = react_prompt_obj.build_prompt()

# Execute the code when providing the first action in the query
@@ -64,22 +61,17 @@ def llm_with_plugin(args, query, item=None, exec_limit=3):
text = ''
while exec_count < exec_limit:
stop_words_list = react_prompt_obj.get_stop_words_list()
output = text_completion(args.llm,
planning_prompt + text,
stop_words=stop_words_list)
output = text_completion(args.llm, planning_prompt + text, stop_words=stop_words_list)

if args.gen_only:
text += output
break

react_parser = get_react_parser(args.model)
action, action_input, output = react_parser.parse_latest_plugin_call(
output)
action, action_input, output = react_parser.parse_latest_plugin_call(output)
if action:
action_input = replace_upload_fname(action_input,
upload_fname_list)
observation = call_tool(action, [action_input],
clear=(exec_count == 0))
action_input = replace_upload_fname(action_input, upload_fname_list)
observation = call_tool(action, [action_input], clear=(exec_count == 0))
output += react_prompt_obj.build_observation(observation)
text += output
exec_count += 1
@@ -114,10 +106,7 @@ def call_tool(plugin_name, plugin_args_list, clear=False):
def process_code_interpreter(item, writer):
query = item['query']
exec_limit = 3 if 'visualization' in item['tags'] else 1
response = llm_with_plugin(args=args,
query=query,
item=item,
exec_limit=exec_limit)
response = llm_with_plugin(args=args, query=query, item=item, exec_limit=exec_limit)
item['gen'] = response

writer.write(json.dumps(item, ensure_ascii=False) + '\n')
@@ -140,10 +129,7 @@ def sequential_processing(args, data_list, process_func, writer):
process_func(item, writer)


process_func_map = {
'gsm8k': process_gsm8k,
'visualization': process_code_interpreter
}
process_func_map = {'gsm8k': process_gsm8k, 'visualization': process_code_interpreter}


def gather_eval_result(model_name):
@@ -161,36 +147,26 @@ def gather_eval_result(model_name):

def eval_metrics(args, test_set, full_output_fname):
# metrics
assert os.path.exists(
full_output_fname), f'Not Found File {full_output_fname}.'
assert os.path.exists(full_output_fname), f'Not Found File {full_output_fname}.'
inference_res = load_jsonl(full_output_fname)
assert len(inference_res) == len(
test_set
), f'There are still {len(test_set)-len(inference_res)} cases left.'
assert len(inference_res) == len(test_set), f'There are still {len(test_set)-len(inference_res)} cases left.'

abs_output_fname = os.path.join(os.path.dirname(os.path.abspath(__file__)),
full_output_fname)
abs_output_fname = os.path.join(os.path.dirname(os.path.abspath(__file__)), full_output_fname)
if args.task == 'gsm8k':
math_code_correctness = eval_gsm8k_acc(abs_output_fname)
global_eval_result['code_correctness'].update(math_code_correctness)
else:
code_executability = eval_code_execution_rate(abs_output_fname,
args.task, args.model)
code_executability = eval_code_execution_rate(abs_output_fname, args.task, args.model)
global_eval_result['code_executability'].update(code_executability)
if args.task in ['all_ci', 'visualization'
] and not args.eval_code_exec_only:
visualization_code_correctness = eval_visualization_acc(
abs_output_fname, args.model, args.vis_judger)
global_eval_result['code_correctness'].update(
visualization_code_correctness)
if args.task in ['all_ci', 'visualization'] and not args.eval_code_exec_only:
visualization_code_correctness = eval_visualization_acc(abs_output_fname, args.model, args.vis_judger)
global_eval_result['code_correctness'].update(visualization_code_correctness)


def main(args):
current_dir = os.getcwd()
os.makedirs(args.output_path, exist_ok=True)
full_output_fname = os.path.join(
args.output_path,
(args.output_fname or f'{args.task}_{args.model}_res.jsonl'))
full_output_fname = os.path.join(args.output_path, (args.output_fname or f'{args.task}_{args.model}_res.jsonl'))

if not os.path.exists(full_output_fname):
with open(full_output_fname, 'w'):
@@ -202,28 +178,21 @@ def main(args):
test_set = dataset['test']
else:
eval_data_path = os.path.join(args.input_path, args.input_fname)
test_set = [
item for item in load_jsonl(eval_data_path)
if args.task in item['tags']
]
test_set = [item for item in load_jsonl(eval_data_path) if args.task in item['tags']]
logging.info(f'Test set: {len(test_set)}')

if args.eval_only:
eval_metrics(args, test_set, full_output_fname)
else:
key = 'question' if args.task == 'gsm8k' else 'query'
cache_question = [item[key] for item in load_jsonl(full_output_fname)
] if not args.force else []
data_list = [
item for item in test_set if item[key] not in cache_question
]
cache_question = [item[key] for item in load_jsonl(full_output_fname)] if not args.force else []
data_list = [item for item in test_set if item[key] not in cache_question]
logging.info(f'Left cases: {len(data_list)}')

# inference
writer_mode = 'w' if args.force else 'a'
f_output = open(full_output_fname, writer_mode, encoding='utf-8')
process_func = process_func_map.get(args.task,
process_code_interpreter)
process_func = process_func_map.get(args.task, process_code_interpreter)
sequential_processing(args, data_list, process_func, f_output)
f_output.close()

@@ -236,33 +205,21 @@

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model',
type=str,
default='qwen-14b-chat',
choices=list(model_path_map.keys()))
parser.add_argument('--task',
type=str,
default='all',
choices=['all', 'gsm8k', 'visualization', 'general'])
parser.add_argument('--model', type=str, default='qwen-14b-chat', choices=list(model_path_map.keys()))
parser.add_argument('--task', type=str, default='all', choices=['all', 'gsm8k', 'visualization', 'general'])
parser.add_argument('--output-path', type=str, default='output_data')
parser.add_argument('--input-path', type=str, default='eval_data')
parser.add_argument('-o', '--output-fname', type=str, default='')
parser.add_argument('-i',
'--input-fname',
type=str,
default='eval_code_interpreter_v1.jsonl')
parser.add_argument('-i', '--input-fname', type=str, default='eval_code_interpreter_v1.jsonl')
parser.add_argument('-f', '--force', action='store_true', default=False)
parser.add_argument('--eval-only', action='store_true', default=False)
parser.add_argument('--eval-code-exec-only',
action='store_true',
default=False)
parser.add_argument('--eval-code-exec-only', action='store_true', default=False)
parser.add_argument('--gen-exec-only', action='store_true', default=False)
parser.add_argument('--gen-only', action='store_true', default=False)
parser.add_argument(
'--vis-judger',
type=str,
default="'gpt-4-vision-preview'",
choices=['gpt-4-vision-preview', 'qwen-vl-chat', 'qwen-vl-plus'])
parser.add_argument('--vis-judger',
type=str,
default="'gpt-4-vision-preview'",
choices=['gpt-4-vision-preview', 'qwen-vl-chat', 'qwen-vl-plus'])
args = parser.parse_args()
return args

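Editor's note: the reflowed `llm_with_plugin` above is the benchmark's core ReAct loop. A simplified sketch of its control flow, with the helper names taken from the diff; the branch taken when no action is parsed is elided in the hunks above and is assumed here to append the output and stop:

    def react_loop(llm, planning_prompt, react_prompt_obj, react_parser, upload_fname_list, exec_limit=3):
        """Simplified sketch of llm_with_plugin's generate -> parse -> execute cycle."""
        text = ''
        exec_count = 0
        while exec_count < exec_limit:
            # Generate until one of the ReAct stop words is hit (e.g. just before an 'Observation:').
            output = text_completion(llm, planning_prompt + text,
                                     stop_words=react_prompt_obj.get_stop_words_list())
            action, action_input, output = react_parser.parse_latest_plugin_call(output)
            if not action:  # no further tool call; treat the output as the final answer (assumed branch)
                text += output
                break
            action_input = replace_upload_fname(action_input, upload_fname_list)
            observation = call_tool(action, [action_input], clear=(exec_count == 0))
            output += react_prompt_obj.build_observation(observation)
            text += output
            exec_count += 1
        return text
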
54 changes: 17 additions & 37 deletions benchmark/metrics/code_execution.py
@@ -62,8 +62,7 @@ def postprocess_code(gen_code, line):
first_action_code = get_action_input_code(line['query'])
gen_code = first_action_code + gen_code

upload_fname_list = line[
'input_file_path'] if line and 'input_file_path' in line else []
upload_fname_list = line['input_file_path'] if line and 'input_file_path' in line else []
gen_code = replace_upload_fname(gen_code, upload_fname_list)

if 'def solution()' in gen_code:
@@ -79,9 +78,7 @@ def postprocess_code(gen_code, line):
return gen_code


def get_action_input_code(text,
model_name='qwen-14b-chat',
extract_first_code=False):
def get_action_input_code(text, model_name='qwen-14b-chat', extract_first_code=False):
action_input_list = []
tmp = text
react_parser = get_react_parser(model_name)
@@ -125,9 +122,7 @@ def eval_code_execution_rate(output_fname,
line['code_error_info'] = ''

# get Action Input code from response
gen_code = get_action_input_code(line['gen'],
model_name=model_name,
extract_first_code=extract_first_code)
gen_code = get_action_input_code(line['gen'], model_name=model_name, extract_first_code=extract_first_code)

if not gen_code:
line['missing_code'] = True
@@ -163,30 +158,22 @@ def eval_code_execution_rate(output_fname,
break

# double check
observation = get_react_parser(model_name).get_first_observation(
line['gen'])
observation = get_react_parser(model_name).get_first_observation(line['gen'])
if line['executable_code'] and ('error:' in observation):
logging.warning(
'The code executes correctly, but it has an error in IPython!')
logging.warning('The code executes correctly, but it has an error in IPython!')
logging.warning(f'Code:\n{gen_code}')
logging.warning(f'IPython error info:\n{observation}')
logging.info('=' * 60)
elif not line['executable_code'] and not ('error:' in observation):
logging.warning(
'The code has an execution error, but it runs correctly in IPython!'
)
logging.warning('The code has an execution error, but it runs correctly in IPython!')
logging.warning(f'Code:\n{gen_code}')
logging.warning(f"Exec error info:\n{line['code_error_info']}")
logging.warning(f'IPython observation:\n{observation}')
logging.info('=' * 60)

# save error data
error_data_list = [
item for item in data_list
if not item['executable_code'] or item['missing_code']
]
error_data_output_fname = os.path.splitext(
output_fname)[0] + '_exec_error.jsonl'
error_data_list = [item for item in data_list if not item['executable_code'] or item['missing_code']]
error_data_output_fname = os.path.splitext(output_fname)[0] + '_exec_error.jsonl'
save_jsonl(error_data_list, error_data_output_fname)

log_result(data_list)
@@ -210,8 +197,7 @@ def log_result(data_list, verbose=True):
logging.info('\n' + line['code'])

logging.info(f'Exec Result {line_id}'.center(60, '-'))
prefix_info = 'Exec Success' if line[
'executable_code'] else 'Exec Error: '
prefix_info = 'Exec Success' if line['executable_code'] else 'Exec Error: '
exec_info = prefix_info + line['code_error_info']
logging.info(exec_info)

@@ -227,31 +213,25 @@ def log_result(data_list, verbose=True):
logging.info(f'task: {key}'.center(60, '='))
key_item_list = [item for item in data_list if key in item['tags']]
all_count = len(key_item_list)
missing_code_count = len(
[item for item in key_item_list if item['missing_code']])
executable_code_count = len(
[item for item in key_item_list if item['executable_code']])
missing_code_count = len([item for item in key_item_list if item['missing_code']])
executable_code_count = len([item for item in key_item_list if item['executable_code']])

logging.info(f'All Test: {all_count}')
logging.info(f'Missing Code: {missing_code_count}')
logging.info(f'Predict Exec Success: {executable_code_count}')
logging.info('Codes available && Execution Rate: {:.2f}'.format(
executable_code_count / (all_count - missing_code_count) * 100))
logging.info('Execution Rate: {:.2f}'.format(executable_code_count /
all_count * 100))
logging.info('Codes available && Execution Rate: {:.2f}'.format(executable_code_count /
(all_count - missing_code_count) * 100))
logging.info('Execution Rate: {:.2f}'.format(executable_code_count / all_count * 100))
logging.info('Non-executable rate: {:.2f}'.format(
(all_count - missing_code_count - executable_code_count) /
all_count * 100))
logging.info('Missing code rate: {:.2f}'.format(missing_code_count /
all_count * 100))
(all_count - missing_code_count - executable_code_count) / all_count * 100))
logging.info('Missing code rate: {:.2f}'.format(missing_code_count / all_count * 100))

if key != 'all_ci':
code_executability[key] = executable_code_count / all_count * 100

if verbose:
logging.info('Error List: ')
error_list = [(item['idx'], item['code_error_info'])
for item in key_item_list if item['code_error_info']]
error_list = [(item['idx'], item['code_error_info']) for item in key_item_list if item['code_error_info']]
error_list.sort(key=lambda x: x[1])
for x in error_list:
logging.info(x)
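Editor's note: the percentages logged at the end of `log_result` follow directly from three counters. A worked example with made-up counts, matching the formulas in the reflowed lines above:

    # Hypothetical counts for one task tag, to make the formulas concrete.
    all_count = 50              # test cases carrying this tag
    missing_code_count = 5      # responses from which no Action Input code could be extracted
    executable_code_count = 36  # extracted snippets that ran without error

    print('{:.2f}'.format(executable_code_count / (all_count - missing_code_count) * 100))  # 80.00  codes available && execution rate
    print('{:.2f}'.format(executable_code_count / all_count * 100))                         # 72.00  execution rate
    print('{:.2f}'.format((all_count - missing_code_count - executable_code_count) / all_count * 100))  # 18.00  non-executable rate
    print('{:.2f}'.format(missing_code_count / all_count * 100))                            # 10.00  missing code rate
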
