Skip to content

Commit

Permalink
Auto-sync-2024-04-29-21-51-09
Browse files Browse the repository at this point in the history
  • Loading branch information
shaoyijia authored Apr 30, 2024
1 parent 5c21acf commit bbbb6e1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
41 changes: 31 additions & 10 deletions src/storm_wiki/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from storm_wiki.modules.persona_generator import StormPersonaGenerator
from storm_wiki.modules.retriever import StormRetriever
from storm_wiki.modules.storm_dataclass import StormInformationTable, StormArticle
from utils import FileIOHelper
from utils import FileIOHelper, makeStringRed


class STORMWikiLMConfigs(LMConfigs):
Expand Down Expand Up @@ -246,6 +246,21 @@ def post_run(self):
call.pop('kwargs') # All kwargs are dumped together to run_config.json.
f.write(json.dumps(call) + '\n')

def _load_information_table_from_local_fs(self, information_table_local_path):
assert os.path.exists(information_table_local_path), makeStringRed(f"{information_table_local_path} not exists. Please set --do-research argument to prepare the conversation_log.json for this topic.")
return StormInformationTable.from_conversation_log_file(information_table_local_path)

def _load_outline_from_local_fs(self, topic, outline_local_path):
assert os.path.exists(outline_local_path), makeStringRed(f"{outline_local_path} not exists. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.")
return StormArticle.from_outline_file(topic=topic, file_path=outline_local_path)

def _load_draft_article_from_local_fs(self, topic, draft_article_path, url_to_info_path):
assert os.path.exists(draft_article_path), makeStringRed(f"{draft_article_path} not exists. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic.")
assert os.path.exists(url_to_info_path), makeStringRed(f"{url_to_info_path} not exists. Please set --do-generate-article argument to prepare the url_to_info.json for this topic.")
article_text = FileIOHelper.load_str(draft_article_path)
references = FileIOHelper.load_json(url_to_info_path)
return StormArticle.from_string(topic_name=topic, article_text=article_text, references=references)

def run(self,
topic: str,
ground_truth_url: str = '',
Expand All @@ -272,6 +287,9 @@ def run(self,
remove_duplicate: If True, remove duplicated content.
callback_handler: A callback handler to handle the intermediate results.
"""
assert do_research or do_generate_outline or do_generate_article or do_polish_article, \
makeStringRed("No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article")

self.topic = topic
self.article_dir_name = topic.replace(' ', '_').replace('/', '_')
self.article_output_dir = os.path.join(self.args.output_dir, self.article_dir_name)
Expand All @@ -282,27 +300,30 @@ def run(self,
if do_research:
information_table = self.run_knowledge_curation_module(ground_truth_url=ground_truth_url,
callback_handler=callback_handler)
else:
information_table = StormInformationTable.from_conversation_log_file(
os.path.join(self.article_output_dir, 'conversation_log.json'))

# outline generation module
outline: StormArticle = None
if do_generate_outline:
# load information table if it's not initialized
if information_table is None:
information_table = self._load_information_table_from_local_fs(os.path.join(self.article_output_dir, 'conversation_log.json'))
outline = self.run_outline_generation_module(information_table=information_table,
callback_handler=callback_handler)
else:
outline = StormArticle.from_outline_file(topic=topic, file_path=os.path.join(self.article_output_dir,
'storm_gen_outline.txt'))

# article generation module
draft_article: StormArticle = None
if do_generate_article:
if information_table is None:
information_table = self._load_information_table_from_local_fs(os.path.join(self.article_output_dir, 'conversation_log.json'))
if outline is None:
outline = self._load_outline_from_local_fs(topic=topic, outline_local_path=os.path.join(self.article_output_dir, 'storm_gen_outline.txt'))
draft_article = self.run_article_generation_module(outline=outline,
information_table=information_table,
callback_handler=callback_handler)

# article polishing module
if do_polish_article:
polished_article = self.run_article_polishing_module(draft_article=draft_article,
remove_duplicate=remove_duplicate)
if draft_article is None:
draft_article_path = os.path.join(self.article_output_dir, 'storm_gen_article.txt')
url_to_info_path = os.path.join(self.article_output_dir, 'url_to_info.json')
draft_article = self._load_draft_article_from_local_fs(topic=topic, draft_article_path=draft_article_path, url_to_info_path=url_to_info_path)
self.run_article_polishing_module(draft_article=draft_article, remove_duplicate=remove_duplicate)
2 changes: 2 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def load_api_key(toml_file_path):
for key, value in data.items():
os.environ[key] = str(value)

def makeStringRed(message):
return f"\033[91m {message}\033[00m"

class ArticleTextProcessing:
@staticmethod
Expand Down

0 comments on commit bbbb6e1

Please sign in to comment.