Skip to content

Commit

Permalink
improve reference retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
tuhahaha authored and JianxinMa committed Oct 9, 2023
1 parent 6cd5427 commit 38c3019
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 12 deletions.
6 changes: 4 additions & 2 deletions qwen_agent/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ def get(self, query: str, records: list, llm=None, stream=False, max_token=5000)
search_agent = SimilaritySearch(type=self.ss_type, llm=llm, stream=stream)
_ref_list = []
for record in records:
_ref_list.append(search_agent.run(record, query))
now_ref_list = search_agent.run(record, query)
if now_ref_list['text']:
_ref_list.append(now_ref_list)

if _ref_list[0]['text'] == []:
if not _ref_list:
_ref_list = self.get_top(records)
# token number
new_ref_list = []
Expand Down
1 change: 1 addition & 0 deletions qwen_agent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ class SimilaritySearchType(Enum):
KeyWord = 'keyword'
QueryMatch = 'querymatch'
LLM = 'llm'
Jaccard = 'jaccard'
12 changes: 8 additions & 4 deletions qwen_agent/tools/parse_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ def parse_html(htmltext):
return html2text.html2text(htmltext)


def replace_multiple_newlines(s):
    """Collapse every run of consecutive newlines in *s* into a single one."""
    collapsed = re.sub(r'\n+', '\n', s)
    return collapsed
def pre_process_html(s):
    """Normalize text extracted from HTML before storing it.

    Collapses runs of consecutive newlines into one and removes the
    reading-list banner string injected into captured pages.
    """
    collapsed = re.sub(r'\n+', '\n', s)
    return collapsed.replace("Add to Qwen's Reading List", '')


def parse_html_bs(path, pre_gen_question=False):
Expand All @@ -45,8 +49,8 @@ def parse_html_bs(path, pre_gen_question=False):
res = []
for page in pages:
print(len(page.page_content.split(' ')))
res.append({'page_content': replace_multiple_newlines(page.page_content), 'metadata': page.metadata, 'related_questions': gen_q(page.page_content)})
res.append({'page_content': pre_process_html(page.page_content), 'metadata': page.metadata, 'related_questions': gen_q(page.page_content)})
else:
res = [{'page_content': replace_multiple_newlines(page.page_content), 'metadata': page.metadata} for page in pages]
res = [{'page_content': pre_process_html(page.page_content), 'metadata': page.metadata} for page in pages]

return res
3 changes: 3 additions & 0 deletions qwen_agent/tools/similarity_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ def __init__(self, type='keyword', llm=None, stream=False):
elif type == SimilaritySearchType.LLM.value:
module = 'qwen_agent.tools.similarity_search_llm'
run_func = importlib.import_module(module).SSLLM(llm).run
elif type == SimilaritySearchType.Jaccard.value:
module = 'qwen_agent.tools.similarity_search_jaccard'
run_func = importlib.import_module(module).SSJaccard(llm, stream).run
else:
raise NotImplementedError
self.run_func = run_func
Expand Down
63 changes: 63 additions & 0 deletions qwen_agent/tools/similarity_search_jaccard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from qwen_agent.schema import RefMaterial
from qwen_agent.utils.util import get_split_word


class SSJaccard:
    """Similarity search over reference material using word-overlap scoring.

    Tokenizes the query and each candidate passage with ``get_split_word``
    and ranks passages by the size of the word-set intersection.  The raw
    intersection count is used instead of the normalized Jaccard ratio so
    that long passages are not penalized for their length.
    """

    def __init__(self, llm=None, stream=False):
        # llm/stream are kept only for interface parity with the other
        # SimilaritySearch backends; this backend does not use them.
        self.llm = llm
        self.stream = stream

    def run(self, line, query):
        """Return the passages of ``line`` most relevant to ``query``.

        Args:
            line: dict with 'url' and 'raw' keys; 'raw' is either a str
                (split on newlines) or a list of str / dict pages, where a
                dict page carries its text under 'page_content'.
            query: the user query string.

        Returns:
            dict from ``RefMaterial.to_dict()``; its 'text' list is empty
            when no passage shares any word with the query.
        """
        wordlist = get_split_word(query)

        content = line['raw']
        if isinstance(content, str):
            content = content.split('\n')

        res = []
        sims = []
        for i, page in enumerate(content):
            sims.append([i, self.filter_section(page, wordlist)])
        sims.sort(key=lambda x: x[1], reverse=True)

        # Guard against empty content: the previous code indexed sims[0]
        # unconditionally and raised IndexError on an empty document.
        if sims and sims[0][1] != 0:
            max_sims = sims[0][1]
            for i, x in enumerate(sims):
                # Keep at least the 4 highest-ranked passages, then stop as
                # soon as the score drops below the maximum.
                if x[1] < max_sims and i > 3:
                    break
                page = content[x[0]]
                text = ''
                if isinstance(page, str):
                    text = page
                elif isinstance(page, dict):
                    text = page['page_content']
                res.append(text)
        return RefMaterial(url=line['url'], text=res).to_dict()

    def filter_section(self, page, wordlist):
        """Score one passage against the query word list.

        Raises:
            TypeError: if ``page`` is neither str nor dict.
        """
        if isinstance(page, str):
            text = page
        elif isinstance(page, dict):
            text = page['page_content']
        else:
            raise TypeError(f'unsupported page type: {type(page)}')

        pagelist = get_split_word(text)
        return self.jaccard_similarity(wordlist, pagelist)

    def jaccard_similarity(self, list1, list2):
        """Return the number of distinct words shared by the two lists.

        Deliberately NOT divided by the union size (true Jaccard) to avoid
        the text-length impact on the score.
        """
        s1 = set(list1)
        s2 = set(list2)
        return len(s1.intersection(s2))
19 changes: 16 additions & 3 deletions qwen_agent/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import sys
import traceback

import jieba
import json5
import requests
import tiktoken
from jieba import analyse

import tiktoken


def print_traceback():
print(''.join(traceback.format_exception(*sys.exc_info())))
Expand Down Expand Up @@ -43,12 +45,23 @@ def send_msg(url, msg):
return requests.post(url, params=msg)


def get_split_word(text):
    """Tokenize ``text`` with jieba and drop stop words.

    Args:
        text: query or passage string; mixed Chinese/English is fine.
            Lower-cased and stripped before segmentation.

    Returns:
        list[str]: tokens in original order, with whitespace tokens,
        common English/Chinese stop words, and punctuation removed.
    """
    # Membership is tested once per token, so use a set (O(1) lookup)
    # instead of scanning a list.
    stop_words = {' ', ' ', '\t', '\n', '\\', 'is', 'are', 'what', 'how',
                  '的', '吗', '是', '了', '怎么', '如何', '什么', '?', '?', '!'}
    text = text.lower()
    return [w for w in jieba.lcut(text.strip()) if w not in stop_words]


def get_key_word(text):
    """Extract key words from ``text`` with jieba's TF-IDF tagger.

    Args:
        text: input string; lower-cased before extraction.

    Returns:
        list[str]: extracted key words minus stop words and punctuation.
    """
    # Same stop-word set as get_split_word, kept in sync so both
    # retrieval paths filter identically.
    stop_words = {' ', ' ', '\t', '\n', '\\', 'is', 'are', 'what', 'how',
                  '的', '吗', '是', '了', '怎么', '如何', '什么', '?', '?', '!'}
    text = text.lower()
    # Debug print removed; switch to the logging module if tracing is needed.
    return [w for w in analyse.extract_tags(text) if w not in stop_words]
Expand Down
2 changes: 1 addition & 1 deletion qwen_server/config_browserqwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# using similarity search on reference material before answer
similarity_search = True # [True, False]
similarity_search_type = 'keyword' # ['keyword', 'querymatch', 'llm']
similarity_search_type = 'jaccard' # ['keyword', 'querymatch', 'llm', 'jaccard']


""" ===== main.py setting ===== """
Expand Down
4 changes: 2 additions & 2 deletions qwen_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def cache_data(data, cache_file):
tmp_html_file = os.path.join(config_browserqwen.cache_root, 'tmp.html')
save_text_to_file(tmp_html_file, data['content'])
data['content'] = parse_html_bs(tmp_html_file, pre_gen_question=config_browserqwen.pre_gen_question)
except Exception as ex:
print(ex)
except Exception:
print_traceback()
extract = data['content'][0]['metadata']['title']

today = datetime.date.today()
Expand Down

0 comments on commit 38c3019

Please sign in to comment.