Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(knext): add kag_agent #319

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix bugs
  • Loading branch information
mfz-ant committed Jul 9, 2024
commit a510ddead7a1619d99ba8fa7128b4dd6870cc772
51 changes: 33 additions & 18 deletions python/knext/knext/ca/logic/agents/divide_and_conquer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio

from knext.ca.common.base import Question, Agent
from knext.ca.logic.modules.reasoner import AnswerQuestionWithContext
from knext.ca.logic.modules.planner import DivideQuestion, IsAtomQuestion, RewriteQuestionBasedOnDeps
from knext.ca.logic.modules.solver import SolveQuestionWithContext
from knext.ca.logic.modules.planner import DivideQuestion, RewriteQuestionBasedOnDeps
from knext.ca.logic.modules.reasoner import IsAtomQuestion


class DivideAndConquerAgent(Agent):
Expand All @@ -14,14 +15,16 @@ def __init__(
divide_question=None,
rewrite_question=None,
is_atom_question=None,
answer_parent_question=None,
answer_atom_question=None,
solve_parent_question=None,
solve_atom_question=None,
use_default_prompt_template=False,
prompt_template_dir=None,
use_en_log=True,
**kwargs
):
self.llm = llm
self.max_depth = max_depth
self.use_en_log = use_en_log

self.divide_question = divide_question if divide_question else DivideQuestion(
self.llm, use_default_prompt_template, prompt_template_dir)
Expand All @@ -32,24 +35,30 @@ def __init__(
self.is_atom_question = is_atom_question if is_atom_question else IsAtomQuestion(
self.llm, use_default_prompt_template, prompt_template_dir)

self.answer_parent_question = answer_parent_question if answer_parent_question else AnswerQuestionWithContext(
self.solve_parent_question = solve_parent_question if solve_parent_question else SolveQuestionWithContext(
self.llm, use_default_prompt_template, prompt_template_dir)
self.answer_atom_question = answer_atom_question if answer_atom_question else AnswerQuestionWithContext(
self.solve_atom_question = solve_atom_question if solve_atom_question else SolveQuestionWithContext(
self.llm, use_default_prompt_template, prompt_template_dir)
extra_info_fetch_tools = []
extra_info_fetch_tools.extend(self.answer_parent_question.get_extra_info_fetch_tools())
extra_info_fetch_tools.extend(self.answer_atom_question.get_extra_info_fetch_tools())
extra_info_fetch_tools.extend(self.solve_parent_question.get_extra_info_fetch_tools())
extra_info_fetch_tools.extend(self.solve_atom_question.get_extra_info_fetch_tools())

super().__init__(extra_info_fetch_tools, intermediate_process_tools)

async def rewrite_question_if_need(self, question: Question):
if len(question.dependencies) > 0:
await asyncio.create_task(self.is_question_deps_ready(question))
rewrited_question = self.rewrite_question.forward(question)
info_dict = {
'status': f'重写问题',
'log_info': f'原问题: {question.question}. 重写后的问题: {rewrited_question}\n{str(question)}',
}
if self.use_en_log:
info_dict = {
'status': f'Rewriting Question',
'log_info': f'Original Question: {question.question}. Rewrited Question: {rewrited_question}\n{str(question)}',
}
else:
info_dict = {
'status': f'重写问题',
'log_info': f'原问题: {question.question}. 重写后的问题: {rewrited_question}\n{str(question)}',
}
self.process_intermediate_info(info_dict)
current_question = Question(
rewrited_question,
Expand All @@ -62,10 +71,16 @@ async def rewrite_question_if_need(self, question: Question):
return question

async def solve_problem_impl(self, question: Question, **kwargs):
info_dict = {
'status': 'start solve_problem_impl',
'log_info': f'current question depth: {question.get_current_depth()}\n{str(question)}'
}
if self.use_en_log:
info_dict = {
'status': 'start solve_problem_impl',
'log_info': f'current question depth: {question.get_current_depth()}\n{str(question)}'
}
else:
info_dict = {
'status': '开始处理问题',
'log_info': f'当前问题深度: {question.get_current_depth()}\n{str(question)}'
}
self.process_intermediate_info(info_dict)

current_question = await self.rewrite_question_if_need(question)
Expand Down Expand Up @@ -110,11 +125,11 @@ async def solve_problem_impl(self, question: Question, **kwargs):
context=children_answers_context
)

answer = self.answer_parent_question.forward(parent_question)
answer = self.solve_parent_question.forward(parent_question)
return answer
else:
atom_question = Question(
question=current_question.question,
)
atom_answer = self.answer_atom_question.forward(atom_question)
atom_answer = self.solve_atom_question.forward(atom_question)
return atom_answer
103 changes: 0 additions & 103 deletions python/knext/knext/ca/logic/modules/extractor.py

This file was deleted.

119 changes: 42 additions & 77 deletions python/knext/knext/ca/logic/modules/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,25 @@
from knext.ca.common.utils import logger


class DivideQuestion(KagBaseModule):
class Planner(KagBaseModule):
def __init__(self, llm_module, use_default_prompt_template, prompt_template_dir, is_prompt_template_cn):
super().__init__(
llm_module=llm_module,
use_default_prompt_template=use_default_prompt_template,
is_prompt_template_cn=is_prompt_template_cn,
prompt_template_dir=prompt_template_dir
)


class DivideQuestion(Planner):
"""
Module for dividing a question into serveral sub questions.

"""
def __init__(self, llm_module, use_default_prompt_template=True, prompt_template_dir=None):
super().__init__(llm_module, use_default_prompt_template, prompt_template_dir)

def __init__(self, llm_module, use_default_prompt_template=False, prompt_template_dir=None,
is_prompt_template_cn=True):
super().__init__(llm_module, use_default_prompt_template, prompt_template_dir, is_prompt_template_cn)

def get_module_name(self):
return "DivideQuestion"
Expand All @@ -32,35 +44,44 @@ def preprocess(self, question: Question):
def postprocess(self, question: Question, llm_output):
def _output_parse(_output_string):
# parse output
parts = _output_string.split("llm_output:") # 分割一次
parts = _output_string.split("llm_output:")
result = parts[-1] if len(parts) > 1 else _output_string
# parse \n
parts_2 = result.split('依赖关系是:')
if self.is_prompt_template_cn:
parts_2 = result.split('依赖关系是:')
else:
parts_2 = result.split('dependent relationship:')
qustion = parts_2[0].strip().split('\n')
dep = parts_2[1].strip().split('\n')
return qustion, dep

def _process_dep(_input_list):
dep_dict = {}
for dep in _input_list:
res = dep.split("依赖")
if self.is_prompt_template_cn:
res = dep.split("依赖")
else:
res = dep.split("deps")
assert len(res) == 2
key = res[0].strip()
dep = res[1].strip('"').split(",") if "," in res[1] else res[1].strip().split(",")
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a lot of string segmentation logic here, please give an example and add comments to explain in detail.

dep_real = [dep_i.strip() for dep_i in dep]
dep_dict[key] = dep_real
return dep_dict

result_string = _output_parse(llm_output)
sub_questions_list = []
sub_logic_forms_list = []
for org_question in result_string[0]:
sub_questions_list.append(org_question)
sub_dependencies = _process_dep(result_string[1])
qk_Q_map = {}

org_question_children = []
for q_ind, sub_question in enumerate(sub_questions_list):
q_key = f'问题{q_ind+1}'
if self.is_prompt_template_cn:
q_key = f'问题{q_ind + 1}'
else:
q_key = f'question{q_ind + 1}'
q_deps_Q_list = []
for q_dep in sub_dependencies[q_key]:
if q_dep == "None":
Expand All @@ -74,13 +95,15 @@ def _process_dep(_input_list):
return list(qk_Q_map.values())


class RewriteQuestionBasedOnDeps(KagBaseModule):
class RewriteQuestionBasedOnDeps(Planner):
"""
Module for rewriting a question based on the current question and dependent question

"""
def __init__(self, llm_module, use_default_prompt_template=True, prompt_template_dir=None):
super().__init__(llm_module, use_default_prompt_template, prompt_template_dir)

def __init__(self, llm_module, use_default_prompt_template=False, prompt_template_dir=None,
is_prompt_template_cn=True):
super().__init__(llm_module, use_default_prompt_template, prompt_template_dir, is_prompt_template_cn)

def get_module_name(self):
return "RewriteQuestionBasedOnDeps"
Expand All @@ -89,10 +112,14 @@ def get_template_var_names(self):
return ['question', 'context']

def preprocess(self, question: Question):
context_string = ''
context_string = ''
if len(question.dependencies) > 0:
for q in question.dependencies:
context_string += f"问题: {q.question} \n 答案: {q.answer}"+'\n'
if self.is_prompt_template_cn:
for q in question.dependencies:
context_string += f"问题: {q.question} \n 答案: {q.answer}" + '\n'
else:
for q in question.dependencies:
context_string += f"question: {q.question} \n answer: {q.answer}" + '\n'
prompt = self.state_dict['prompt_template'].substitute(
question=question.question,
context=context_string
Expand All @@ -107,65 +134,3 @@ def postprocess(self, question: Question, llm_output):
return result


class IsAtomQuestion(KagBaseModule):
"""
Module for determining if a question pertains to atomic concepts based on the input question.

"""
def __init__(self, llm, use_default_prompt_template, prompt_template_dir):
use_default_prompt_template = False
super().__init__(llm, use_default_prompt_template, prompt_template_dir)

def get_module_name(self):
return "IsAtomQuestion"

def get_template_var_names(self):
return ['question']

def preprocess(self, question: Question):
prompt = self.state_dict['prompt_template'].substitute(
question=question.question,
)
return prompt

def postprocess(self, question: Question, llm_output):
llm_output = llm_output.split(':')[-1].strip()
if llm_output == '是':
return True
elif llm_output == '否':
return False
else:
warning_reuslt = f'结果为:{llm_output}'
logger.warning(f'{warning_reuslt}')
return True


class DoesQuestionNeedExtraInfo(KagBaseModule):
"""
Module for determining if a question needs additional information based on the question.

"""
def __init__(self, llm_module, use_default_prompt_template=True, prompt_template_dir=None):
super().__init__(llm_module, use_default_prompt_template, prompt_template_dir)

def get_module_name(self):
return "DoesQuestionNeedExtraInfo"

def get_template_var_names(self):
return ['question']

def preprocess(self, question: Question):
prompt = self.state_dict['prompt_template'].substitute(
question=question.question
)
return prompt

def postprocess(self, question: Question, llm_output):
if llm_output == '是':
return True
elif llm_output == '否':
return False
else:
warning_reuslt = f'结果为:{llm_output}'
logger.debug(f'{warning_reuslt}')
return False
Loading
Loading