Skip to content

Commit 698240a

Browse files
author
youfeng
committedNov 1, 2023
fx bug
1 parent 4b24b97 commit 698240a

File tree

4 files changed

+15
-8
lines changed

4 files changed

+15
-8
lines changed
 

‎code/结婚买房代代韭菜/serving/fin_qa/art.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import re
22

33
from .keywords import comp_short_dict
4+
from .query_analyze import query_analyze
5+
from .utils import lcs_sub
46

57

68
'''

‎code/结婚买房代代韭菜/serving/fin_qa/config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
# data
1818
HTML_PATH = "./test_data/allhtml"
1919
PDF_PATH = "D:\\data\\chatglm_llm_fintech_raw_dataset\\allpdf"
20-
PDF_IDX_PATH = "./test_data/pdf_name.txt"
20+
PDF_IDX_PATH = "./test_data/C-list-pdf-name.txt"
2121
TXT_PATH = "D:\\data\\chatglm_llm_fintech_raw_dataset\\alltxt"
2222
STOPWORDS_PATH = "fin_qa/resources/stopwords.txt"
23-
QUESTION_PATH = "./test_data/queries.json"
23+
QUESTION_PATH = "./test_data/C-list-question.json"
2424
OUTPUT_PATH = "result.json"
2525

2626
# DB

‎code/结婚买房代代韭菜/serving/fin_qa/normalize/normalize_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def pack_sql_res(sql, query, query_analyze_result, type_, res):
3030
res_T = [[res[j][i] for j in range(len(res))] for i in range(len(res[0]))]
3131
res_dic = {s: r for s, r in zip(selects, res_T)}
3232

33+
# print(f"\n\n\n res: {res}\n res_dic: {res_dic}")
34+
3335
# TODO:处理小数位数问题,比较复杂,这个问题应该在建库时作一个原始字符串段可能好一点
3436
# 拿到所有年份所有公司的年报小数位数
3537
# dot_bits = {}

‎code/结婚买房代代韭菜/serving/predict.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ def normalize(args, input_path, output_path):
8383
exe_sql = translate_sql(sql)
8484
sql_res = cursor.execute(exe_sql).fetchall()
8585

86+
# print(f"\n# query:{query}\ntype:{type_}\nexe_sql:{exe_sql}\nsql_res:{sql_res}")
87+
8688
if args.mode == "listC_final":
8789
# 法定代表人使用规则生成答案
88-
res = pack_normalize_res(type_, sql, gen_sql_res_json(cursor.description, sql_res))
90+
res = pack_sql_res(sql, query, query_analyze_result, type_, sql_res)
8991
if "法定代表人" in query and len(sql_res) > 1:
9092

9193
line["norm_prompt"] = str(res)
@@ -111,6 +113,7 @@ def normalize(args, input_path, output_path):
111113
line["answer"] = gen_ans
112114
except Exception as e:
113115
print(query)
116+
print(exe_sql)
114117
print(sql_res)
115118
print("ERR: ", e)
116119
line["type"] = "type3"
@@ -144,12 +147,12 @@ def solve_type3(args, input_path):
144147
def predict(args):
145148
# 一共加载四次模型
146149

147-
# 1. 路由问题类型
148-
router(args, ROUTER_FILE_PATH)
150+
# # 1. 路由问题类型
151+
# router(args, ROUTER_FILE_PATH)
149152

150-
# 2. 对部分问题作nl2sql
151-
reset_transformer_chatglm2(pre_seq_len=NL2SQL_PRE_SEQ_LEN, checkpoint_path=NL2SQL_CHECKPOINT_PATH)
152-
nl2sql(args, ROUTER_FILE_PATH, SQL_FILE_PATH)
153+
# # 2. 对部分问题作nl2sql
154+
# reset_transformer_chatglm2(pre_seq_len=NL2SQL_PRE_SEQ_LEN, checkpoint_path=NL2SQL_CHECKPOINT_PATH)
155+
# nl2sql(args, ROUTER_FILE_PATH, SQL_FILE_PATH)
153156

154157
# 3. 对于使用sql进行查询的结果进行回答问题
155158
reset_transformer_chatglm2(pre_seq_len=NORMALIZE_PRE_SEQ_LEN, checkpoint_path=NORMALIZE_CHECKPOINT_PATH)

0 commit comments

Comments
 (0)
Please sign in to comment.