From 85a1e6e4391059635109304c1f51fb0c486fe6bc Mon Sep 17 00:00:00 2001 From: zenith110 Date: Sun, 28 Jul 2024 10:32:14 -0400 Subject: [PATCH] Reverted back everything except requirements.txt, the new example file, and rm.py --- knowledge_storm/__init__.py | 2 +- knowledge_storm/interface.py | 69 ++-- knowledge_storm/lm.py | 228 +++++-------- knowledge_storm/storm_wiki/engine.py | 310 +++++++----------- .../storm_wiki/modules/article_generation.py | 117 +++---- .../storm_wiki/modules/article_polish.py | 46 +-- .../storm_wiki/modules/knowledge_curation.py | 249 +++++--------- .../storm_wiki/modules/outline_generation.py | 106 +++--- .../storm_wiki/modules/persona_generator.py | 64 ++-- .../storm_wiki/modules/retriever.py | 23 +- .../storm_wiki/modules/storm_dataclass.py | 233 +++++-------- knowledge_storm/utils.py | 166 ++++------ 12 files changed, 589 insertions(+), 1024 deletions(-) diff --git a/knowledge_storm/__init__.py b/knowledge_storm/__init__.py index 74dcabbe..f1fd18ea 100644 --- a/knowledge_storm/__init__.py +++ b/knowledge_storm/__init__.py @@ -1,5 +1,5 @@ from .storm_wiki.engine import ( STORMWikiLMConfigs, STORMWikiRunnerArguments, - STORMWikiRunner, + STORMWikiRunner ) diff --git a/knowledge_storm/interface.py b/knowledge_storm/interface.py index f6c11bd9..03df2fb6 100644 --- a/knowledge_storm/interface.py +++ b/knowledge_storm/interface.py @@ -5,9 +5,7 @@ from collections import OrderedDict from typing import Dict, List, Optional, Union -logging.basicConfig( - level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s" -) +logging.basicConfig(level=logging.INFO, format='%(name)s : %(levelname)-8s : %(message)s') logger = logging.getLogger(__name__) @@ -72,9 +70,7 @@ class Article(ABC): def __init__(self, topic_name): self.root = ArticleSectionNode(topic_name) - def find_section( - self, node: ArticleSectionNode, name: str - ) -> Optional[ArticleSectionNode]: + def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]: """ Return the node of the section given the section name. @@ -156,9 +152,7 @@ def prune_empty_nodes(self, node=None): if node is None: node = self.root - node.children[:] = [ - child for child in node.children if self.prune_empty_nodes(child) - ] + node.children[:] = [child for child in node.children if self.prune_empty_nodes(child)] if (node.content is None or node.content == "") and not node.children: return None @@ -184,9 +178,7 @@ def update_search_top_k(self, k): def collect_and_reset_rm_usage(self): combined_usage = [] for attr_name in self.__dict__: - if "_rm" in attr_name and hasattr( - getattr(self, attr_name), "get_usage_and_reset" - ): + if '_rm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'): combined_usage.append(getattr(self, attr_name).get_usage_and_reset()) name_to_usage = {} @@ -248,9 +240,7 @@ class OutlineGenerationModule(ABC): """ @abstractmethod - def generate_outline( - self, topic: str, information_table: InformationTable, **kwargs - ) -> Article: + def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> Article: """ Generate outline for the article. 
Required arguments include: topic: the topic of interest @@ -273,13 +263,11 @@ class ArticleGenerationModule(ABC): """ @abstractmethod - def generate_article( - self, - topic: str, - information_table: InformationTable, - article_with_outline: Article, - **kwargs, - ) -> Article: + def generate_article(self, + topic: str, + information_table: InformationTable, + article_with_outline: Article, + **kwargs) -> Article: """ Generate article. Required arguments include: topic: the topic of interest @@ -324,15 +312,14 @@ def wrapper(self, *args, **kwargs): class LMConfigs(ABC): """Abstract base class for language model configurations of the knowledge curation engine. - The language model used for each part should be declared with a suffix '_lm' in the attribute name. - """ + The language model used for each part should be declared with a suffix '_lm' in the attribute name.""" def __init__(self): pass def init_check(self): for attr_name in self.__dict__: - if "_lm" in attr_name and getattr(self, attr_name) is None: + if '_lm' in attr_name and getattr(self, attr_name) is None: logging.warning( f"Language model for {attr_name} is not initialized. Please call set_{attr_name}()" ) @@ -340,7 +327,7 @@ def init_check(self): def collect_and_reset_lm_history(self): history = [] for attr_name in self.__dict__: - if "_lm" in attr_name and hasattr(getattr(self, attr_name), "history"): + if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'): history.extend(getattr(self, attr_name).history) getattr(self, attr_name).history = [] @@ -349,9 +336,7 @@ def collect_and_reset_lm_history(self): def collect_and_reset_lm_usage(self): combined_usage = [] for attr_name in self.__dict__: - if "_lm" in attr_name and hasattr( - getattr(self, attr_name), "get_usage_and_reset" - ): + if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'): combined_usage.append(getattr(self, attr_name).get_usage_and_reset()) model_name_to_usage = {} @@ -360,12 +345,8 @@ def collect_and_reset_lm_usage(self): if model_name not in model_name_to_usage: model_name_to_usage[model_name] = tokens else: - model_name_to_usage[model_name]["prompt_tokens"] += tokens[ - "prompt_tokens" - ] - model_name_to_usage[model_name]["completion_tokens"] += tokens[ - "completion_tokens" - ] + model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens'] + model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens'] return model_name_to_usage @@ -373,9 +354,8 @@ def log(self): return OrderedDict( { - attr_name: getattr(self, attr_name).kwargs - for attr_name in self.__dict__ - if "_lm" in attr_name and hasattr(getattr(self, attr_name), "kwargs") + attr_name: getattr(self, attr_name).kwargs for attr_name in self.__dict__ if + '_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs') } ) @@ -399,21 +379,16 @@ def wrapper(*args, **kwargs): self.time[func.__name__] = execution_time logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds") self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage() - if hasattr(self, "retriever"): - self.rm_cost[func.__name__] = ( - self.retriever.collect_and_reset_rm_usage() - ) + if hasattr(self, 'retriever'): + self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage() return result return wrapper def apply_decorators(self): """Apply decorators to methods that need them.""" - methods_to_decorate = [ - method_name - for method_name in dir(self) - if callable(getattr(self, method_name)) and 
method_name.startswith("run_") - ] + methods_to_decorate = [method_name for method_name in dir(self) + if callable(getattr(self, method_name)) and method_name.startswith('run_')] for method_name in methods_to_decorate: original_method = getattr(self, method_name) decorated_method = self.log_execution_time_and_lm_rm_usage(original_method) diff --git a/knowledge_storm/lm.py b/knowledge_storm/lm.py index 1aa34d24..e9c50852 100644 --- a/knowledge_storm/lm.py +++ b/knowledge_storm/lm.py @@ -9,10 +9,7 @@ import requests from dsp import ERRORS, backoff_hdlr, giveup_hdlr from dsp.modules.hf import openai_to_hf -from dsp.modules.hf_client import ( - send_hfvllm_request_v00, - send_hftgi_request_v01_wrapped, -) +from dsp.modules.hf_client import send_hfvllm_request_v00, send_hftgi_request_v01_wrapped from transformers import AutoTokenizer try: @@ -25,11 +22,11 @@ class OpenAIModel(dspy.OpenAI): """A wrapper class for dspy.OpenAI.""" def __init__( - self, - model: str = "gpt-3.5-turbo-instruct", - api_key: Optional[str] = None, - model_type: Literal["chat", "text"] = None, - **kwargs, + self, + model: str = "gpt-3.5-turbo-instruct", + api_key: Optional[str] = None, + model_type: Literal["chat", "text"] = None, + **kwargs ): super().__init__(model=model, api_key=api_key, model_type=model_type, **kwargs) self._token_usage_lock = threading.Lock() @@ -38,20 +35,17 @@ def __init__( def log_usage(self, response): """Log the total tokens from the OpenAI API response.""" - usage_data = response.get("usage") + usage_data = response.get('usage') if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get("prompt_tokens", 0) - self.completion_tokens += usage_data.get("completion_tokens", 0) + self.prompt_tokens += usage_data.get('prompt_tokens', 0) + self.completion_tokens += usage_data.get('completion_tokens', 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.kwargs.get("model") - or self.kwargs.get("engine"): { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.kwargs.get('model') or self.kwargs.get('engine'): + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -59,11 +53,11 @@ def get_usage_and_reset(self): return usage def __call__( - self, - prompt: str, - only_completed: bool = True, - return_sorted: bool = False, - **kwargs, + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, ) -> list[dict[str, Any]]: """Copied from dspy/dsp/modules/gpt3.py with the addition of tracking token usage.""" @@ -115,11 +109,11 @@ class DeepSeekModel(dspy.OpenAI): """A wrapper class for DeepSeek API, compatible with dspy.OpenAI.""" def __init__( - self, - model: str = "deepseek-chat", - api_key: Optional[str] = None, - api_base: str = "https://api.deepseek.com", - **kwargs, + self, + model: str = "deepseek-chat", + api_key: Optional[str] = None, + api_base: str = "https://api.deepseek.com", + **kwargs ): super().__init__(model=model, api_key=api_key, api_base=api_base, **kwargs) self._token_usage_lock = threading.Lock() @@ -129,25 +123,21 @@ def __init__( self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY") self.api_base = api_base if not self.api_key: - raise ValueError( - "DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY" - ) + raise ValueError("DeepSeek API key must be provided either as an argument or as an 
environment variable DEEPSEEK_API_KEY") def log_usage(self, response): """Log the total tokens from the DeepSeek API response.""" - usage_data = response.get("usage") + usage_data = response.get('usage') if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get("prompt_tokens", 0) - self.completion_tokens += usage_data.get("completion_tokens", 0) + self.prompt_tokens += usage_data.get('prompt_tokens', 0) + self.completion_tokens += usage_data.get('completion_tokens', 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.model: { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.model: + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -164,25 +154,23 @@ def _create_completion(self, prompt: str, **kwargs): """Create a completion using the DeepSeek API.""" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", + "Authorization": f"Bearer {self.api_key}" } data = { "model": self.model, "messages": [{"role": "user", "content": prompt}], - **kwargs, + **kwargs } - response = requests.post( - f"{self.api_base}/v1/chat/completions", headers=headers, json=data - ) + response = requests.post(f"{self.api_base}/v1/chat/completions", headers=headers, json=data) response.raise_for_status() return response.json() def __call__( - self, - prompt: str, - only_completed: bool = True, - return_sorted: bool = False, - **kwargs, + self, + prompt: str, + only_completed: bool = True, + return_sorted: bool = False, + **kwargs, ) -> list[dict[str, Any]]: """Call the DeepSeek API to generate completions.""" assert only_completed, "for now" @@ -208,46 +196,35 @@ def __call__( class AzureOpenAIModel(dspy.AzureOpenAI): """A wrapper class for dspy.AzureOpenAI.""" - def __init__( - self, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - model: str = "gpt-3.5-turbo-instruct", - api_key: Optional[str] = None, - model_type: Literal["chat", "text"] = "chat", - **kwargs, + self, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + model: str = "gpt-3.5-turbo-instruct", + api_key: Optional[str] = None, + model_type: Literal["chat", "text"] = "chat", + **kwargs, ): super().__init__( - api_base=api_base, - api_version=api_version, - model=model, - api_key=api_key, - model_type=model_type, - **kwargs, - ) + api_base=api_base, api_version=api_version, model=model, api_key=api_key, model_type=model_type, **kwargs) self._token_usage_lock = threading.Lock() self.prompt_tokens = 0 self.completion_tokens = 0 def log_usage(self, response): """Log the total tokens from the OpenAI API response. - Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage. 
- """ - usage_data = response.get("usage") + Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage.""" + usage_data = response.get('usage') if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get("prompt_tokens", 0) - self.completion_tokens += usage_data.get("completion_tokens", 0) + self.prompt_tokens += usage_data.get('prompt_tokens', 0) + self.completion_tokens += usage_data.get('completion_tokens', 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.kwargs.get("model") - or self.kwargs.get("engine"): { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.kwargs.get('model') or self.kwargs.get('engine'): + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -259,11 +236,11 @@ class ClaudeModel(dspy.dsp.modules.lm.LM): """Copied from dspy/dsp/modules/anthropic.py with the addition of tracking token usage.""" def __init__( - self, - model: str, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - **kwargs, + self, + model: str, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + **kwargs, ): super().__init__(model) try: @@ -272,21 +249,12 @@ def __init__( raise ImportError("Claude requires `pip install anthropic`.") from err self.provider = "anthropic" - self.api_key = api_key = ( - os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key - ) - self.api_base = ( - "https://api.anthropic.com/v1/messages" if api_base is None else api_base - ) - self.kwargs = { - "temperature": kwargs.get("temperature", 0.0), - "max_tokens": min(kwargs.get("max_tokens", 4096), 4096), - "top_p": kwargs.get("top_p", 1.0), - "top_k": kwargs.get("top_k", 1), - "n": kwargs.pop("n", kwargs.pop("num_generations", 1)), - **kwargs, - "model": model, - } + self.api_key = api_key = os.environ.get("ANTHROPIC_API_KEY") if api_key is None else api_key + self.api_base = "https://api.anthropic.com/v1/messages" if api_base is None else api_base + self.kwargs = {"temperature": kwargs.get("temperature", 0.0), + "max_tokens": min(kwargs.get("max_tokens", 4096), 4096), "top_p": kwargs.get("top_p", 1.0), + "top_k": kwargs.get("top_k", 1), "n": kwargs.pop("n", kwargs.pop("num_generations", 1)), + **kwargs, "model": model} self.history: list[dict[str, Any]] = [] self.client = Anthropic(api_key=api_key) self.model = model @@ -306,10 +274,8 @@ def log_usage(self, response): def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.model: { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.model: + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -341,7 +307,7 @@ def basic_request(self, prompt: str, **kwargs): "usage": { "input_tokens": response.usage.input_tokens, "output_tokens": response.usage.output_tokens, - }, + } }, "kwargs": kwargs, "raw_kwargs": raw_kwargs, @@ -411,7 +377,10 @@ def _generate(self, prompt, **kwargs): # "max_tokens": kwargs["max_tokens"], # "temperature": kwargs["temperature"], # } - payload = {"prompt": prompt, **kwargs} + payload = { + "prompt": prompt, + **kwargs + } response = send_hfvllm_request_v00( f"{self.url}/v1/completions", @@ -444,17 +413,11 @@ def __init__(self, model, port, url="http://localhost", **kwargs): super().__init__(model=model, 
base_url=f"{url}:{port}", **kwargs) # Store additional kwargs for the generate method. self.kwargs = {**self.kwargs, **kwargs} - + class TGIClient(dspy.HFClientTGI): def __init__(self, model, port, url, http_request_kwargs=None, **kwargs): - super().__init__( - model=model, - port=port, - url=url, - http_request_kwargs=http_request_kwargs, - **kwargs, - ) + super().__init__(model=model, port=port, url=url, http_request_kwargs=http_request_kwargs, **kwargs) def _generate(self, prompt, **kwargs): """Copied from dspy/dsp/modules/hf_client.py with the addition of removing hard-coded parameters.""" @@ -493,8 +456,8 @@ def _generate(self, prompt, **kwargs): completions = [json_response["generated_text"]] if ( - "details" in json_response - and "best_of_sequences" in json_response["details"] + "details" in json_response + and "best_of_sequences" in json_response["details"] ): completions += [ x["generated_text"] @@ -511,22 +474,13 @@ def _generate(self, prompt, **kwargs): class TogetherClient(dspy.HFModel): """A wrapper class for dspy.Together.""" - def __init__( - self, - model, - apply_tokenizer_chat_template=False, - hf_tokenizer_name=None, - **kwargs, - ): + def __init__(self, model, apply_tokenizer_chat_template=False, hf_tokenizer_name=None, **kwargs): """Copied from dspy/dsp/modules/hf_client.py with the support of applying tokenizer chat template.""" super().__init__(model=model, is_client=True) self.session = requests.Session() - self.api_base = ( - "https://api.together.xyz/v1/completions" - if os.getenv("TOGETHER_API_BASE") is None - else os.getenv("TOGETHER_API_BASE") - ) + self.api_base = "https://api.together.xyz/v1/completions" if os.getenv( + "TOGETHER_API_BASE") is None else os.getenv("TOGETHER_API_BASE") self.token = os.getenv("TOGETHER_API_KEY") self.model = model @@ -538,9 +492,7 @@ def __init__( logging.info("Loading huggingface tokenizer.") if hf_tokenizer_name is None: hf_tokenizer_name = self.model - self.tokenizer = AutoTokenizer.from_pretrained( - hf_tokenizer_name, cache_dir=kwargs.get("cache_dir", None) - ) + self.tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_name, cache_dir=kwargs.get("cache_dir", None)) stop_default = "\n\n---" @@ -560,19 +512,17 @@ def __init__( def log_usage(self, response): """Log the total tokens from the OpenAI API response.""" - usage_data = response.get("usage") + usage_data = response.get('usage') if usage_data: with self._token_usage_lock: - self.prompt_tokens += usage_data.get("prompt_tokens", 0) - self.completion_tokens += usage_data.get("completion_tokens", 0) + self.prompt_tokens += usage_data.get('prompt_tokens', 0) + self.completion_tokens += usage_data.get('completion_tokens', 0) def get_usage_and_reset(self): """Get the total tokens used and reset the token usage.""" usage = { - self.model: { - "prompt_tokens": self.prompt_tokens, - "completion_tokens": self.completion_tokens, - } + self.model: + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} } self.prompt_tokens = 0 self.completion_tokens = 0 @@ -597,18 +547,14 @@ def _generate(self, prompt, use_chat_api=False, **kwargs): top_k = kwargs.get("top_k", 50) repetition_penalty = kwargs.get("repetition_penalty", 1) if self.apply_tokenizer_chat_template: - prompt = self.tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], tokenize=False - ) + prompt = self.tokenizer.apply_chat_template([{"role": "user", "content": prompt}], tokenize=False) # prompt = f"[INST]{prompt}[/INST]" if self.use_inst_template else prompt if 
use_chat_api: url = f"{self.api_base}/chat/completions" messages = [ - { - "role": "system", - "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections.", - }, + {"role": "system", + "content": "You are a helpful assistant. You must continue the user text directly without *any* additional interjections."}, {"role": "user", "content": prompt}, ] body = { @@ -641,13 +587,9 @@ def _generate(self, prompt, use_chat_api=False, **kwargs): self.log_usage(resp_json) if use_chat_api: # completions = [resp_json['output'].get('choices', [])[0].get('message', {}).get('content', "")] - completions = [ - resp_json.get("choices", [])[0] - .get("message", {}) - .get("content", "") - ] + completions = [resp_json.get('choices', [])[0].get('message', {}).get('content', "")] else: # completions = [resp_json['output'].get('choices', [])[0].get('text', "")] - completions = [resp_json.get("choices", [])[0].get("text", "")] + completions = [resp_json.get('choices', [])[0].get('text', "")] response = {"prompt": prompt, "choices": [{"text": c} for c in completions]} return response diff --git a/knowledge_storm/storm_wiki/engine.py b/knowledge_storm/storm_wiki/engine.py index 746a07b0..e0c8dfcc 100644 --- a/knowledge_storm/storm_wiki/engine.py +++ b/knowledge_storm/storm_wiki/engine.py @@ -28,52 +28,43 @@ class STORMWikiLMConfigs(LMConfigs): """ def __init__(self): - self.conv_simulator_lm = ( - None # LLM used in conversation simulator except for question asking. - ) + self.conv_simulator_lm = None # LLM used in conversation simulator except for question asking. self.question_asker_lm = None # LLM used in question asking. self.outline_gen_lm = None # LLM used in outline generation. self.article_gen_lm = None # LLM used in article generation. self.article_polish_lm = None # LLM used in article polishing. def init_openai_model( - self, - openai_api_key: str, - openai_type: Literal["openai", "azure"], - api_base: Optional[str] = None, - api_version: Optional[str] = None, - temperature: Optional[float] = 1.0, - top_p: Optional[float] = 0.9, + self, + openai_api_key: str, + openai_type: Literal["openai", "azure"], + api_base: Optional[str] = None, + api_version: Optional[str] = None, + temperature: Optional[float] = 1.0, + top_p: Optional[float] = 0.9 ): """Legacy: Corresponding to the original setup in the NAACL'24 paper.""" openai_kwargs = { - "api_key": openai_api_key, - "api_provider": openai_type, - "temperature": temperature, - "top_p": top_p, - "api_base": None, + 'api_key': openai_api_key, + 'api_provider': openai_type, + 'temperature': temperature, + 'top_p': top_p, + 'api_base': None } - if openai_type and openai_type == "openai": - self.conv_simulator_lm = OpenAIModel( - model="gpt-3.5-turbo-instruct", max_tokens=500, **openai_kwargs - ) - self.question_asker_lm = OpenAIModel( - model="gpt-3.5-turbo", max_tokens=500, **openai_kwargs - ) + if openai_type and openai_type == 'openai': + self.conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo-instruct', + max_tokens=500, **openai_kwargs) + self.question_asker_lm = OpenAIModel(model='gpt-3.5-turbo', + max_tokens=500, **openai_kwargs) # 1/12/2024: Update gpt-4 to gpt-4-1106-preview. (Currently keep the original setup when using azure.) 
- self.outline_gen_lm = OpenAIModel( - model="gpt-4-0125-preview", max_tokens=400, **openai_kwargs - ) - self.article_gen_lm = OpenAIModel( - model="gpt-4o-2024-05-13", max_tokens=700, **openai_kwargs - ) - self.article_polish_lm = OpenAIModel( - model="gpt-4o-2024-05-13", max_tokens=4000, **openai_kwargs - ) + self.outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', + max_tokens=400, **openai_kwargs) + self.article_gen_lm = OpenAIModel(model='gpt-4o-2024-05-13', + max_tokens=700, **openai_kwargs) + self.article_polish_lm = OpenAIModel(model='gpt-4o-2024-05-13', + max_tokens=4000, **openai_kwargs) else: - logging.warning( - "No valid OpenAI API provider is provided. Cannot use default LLM configurations." - ) + logging.warning('No valid OpenAI API provider is provided. Cannot use default LLM configurations.') def set_conv_simulator_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.conv_simulator_lm = model @@ -94,21 +85,16 @@ def set_article_polish_lm(self, model: Union[dspy.dsp.LM, dspy.dsp.HFModel]): @dataclass class STORMWikiRunnerArguments: """Arguments for controlling the STORM Wiki pipeline.""" - output_dir: str = field( metadata={"help": "Output directory for the results."}, ) max_conv_turn: int = field( default=3, - metadata={ - "help": "Maximum number of questions in conversational question asking." - }, + metadata={"help": "Maximum number of questions in conversational question asking."}, ) max_perspective: int = field( default=3, - metadata={ - "help": "Maximum number of perspectives to consider in perspective-guided question asking." - }, + metadata={"help": "Maximum number of perspectives to consider in perspective-guided question asking."}, ) max_search_queries_per_turn: int = field( default=3, @@ -128,27 +114,24 @@ class STORMWikiRunnerArguments: ) max_thread_num: int = field( default=10, - metadata={ - "help": "Maximum number of threads to use. " - "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API." - }, + metadata={"help": "Maximum number of threads to use. 
" + "Consider reducing it if keep getting 'Exceed rate limit' error when calling LM API."}, ) class STORMWikiRunner(Engine): """STORM Wiki pipeline runner.""" - def __init__( - self, args: STORMWikiRunnerArguments, lm_configs: STORMWikiLMConfigs, rm - ): + def __init__(self, + args: STORMWikiRunnerArguments, + lm_configs: STORMWikiLMConfigs, + rm): super().__init__(lm_configs=lm_configs) self.args = args self.lm_configs = lm_configs self.retriever = StormRetriever(rm=rm, k=self.args.retrieve_top_k) - storm_persona_generator = StormPersonaGenerator( - self.lm_configs.question_asker_lm - ) + storm_persona_generator = StormPersonaGenerator(self.lm_configs.question_asker_lm) self.storm_knowledge_curation_module = StormKnowledgeCurationModule( retriever=self.retriever, persona_generator=storm_persona_generator, @@ -157,7 +140,7 @@ def __init__( max_search_queries_per_turn=self.args.max_search_queries_per_turn, search_top_k=self.args.search_top_k, max_conv_turn=self.args.max_conv_turn, - max_thread_num=self.args.max_thread_num, + max_thread_num=self.args.max_thread_num ) self.storm_outline_generation_module = StormOutlineGenerationModule( outline_gen_lm=self.lm_configs.outline_gen_lm @@ -165,96 +148,73 @@ def __init__( self.storm_article_generation = StormArticleGenerationModule( article_gen_lm=self.lm_configs.article_gen_lm, retrieve_top_k=self.args.retrieve_top_k, - max_thread_num=self.args.max_thread_num, + max_thread_num=self.args.max_thread_num ) self.storm_article_polishing_module = StormArticlePolishingModule( article_gen_lm=self.lm_configs.article_gen_lm, - article_polish_lm=self.lm_configs.article_polish_lm, + article_polish_lm=self.lm_configs.article_polish_lm ) self.lm_configs.init_check() self.apply_decorators() - def run_knowledge_curation_module( - self, - ground_truth_url: str = "None", - callback_handler: BaseCallbackHandler = None, - ) -> StormInformationTable: - - information_table, conversation_log = ( - self.storm_knowledge_curation_module.research( - topic=self.topic, - ground_truth_url=ground_truth_url, - callback_handler=callback_handler, - max_perspective=self.args.max_perspective, - disable_perspective=False, - return_conversation_log=True, - ) - ) + def run_knowledge_curation_module(self, + ground_truth_url: str = "None", + callback_handler: BaseCallbackHandler = None) -> StormInformationTable: - FileIOHelper.dump_json( - conversation_log, - os.path.join(self.article_output_dir, "conversation_log.json"), - ) - information_table.dump_url_to_info( - os.path.join(self.article_output_dir, "raw_search_results.json") + information_table, conversation_log = self.storm_knowledge_curation_module.research( + topic=self.topic, + ground_truth_url=ground_truth_url, + callback_handler=callback_handler, + max_perspective=self.args.max_perspective, + disable_perspective=False, + return_conversation_log=True ) + + FileIOHelper.dump_json(conversation_log, os.path.join(self.article_output_dir, 'conversation_log.json')) + information_table.dump_url_to_info(os.path.join(self.article_output_dir, 'raw_search_results.json')) return information_table - def run_outline_generation_module( - self, - information_table: StormInformationTable, - callback_handler: BaseCallbackHandler = None, - ) -> StormArticle: + def run_outline_generation_module(self, + information_table: StormInformationTable, + callback_handler: BaseCallbackHandler = None) -> StormArticle: outline, draft_outline = self.storm_outline_generation_module.generate_outline( topic=self.topic, information_table=information_table, 
return_draft_outline=True, - callback_handler=callback_handler, - ) - outline.dump_outline_to_file( - os.path.join(self.article_output_dir, "storm_gen_outline.txt") - ) - draft_outline.dump_outline_to_file( - os.path.join(self.article_output_dir, "direct_gen_outline.txt") + callback_handler=callback_handler ) + outline.dump_outline_to_file(os.path.join(self.article_output_dir, 'storm_gen_outline.txt')) + draft_outline.dump_outline_to_file(os.path.join(self.article_output_dir, "direct_gen_outline.txt")) return outline - def run_article_generation_module( - self, - outline: StormArticle, - information_table=StormInformationTable, - callback_handler: BaseCallbackHandler = None, - ) -> StormArticle: + def run_article_generation_module(self, + outline: StormArticle, + information_table=StormInformationTable, + callback_handler: BaseCallbackHandler = None) -> StormArticle: draft_article = self.storm_article_generation.generate_article( topic=self.topic, information_table=information_table, article_with_outline=outline, - callback_handler=callback_handler, - ) - draft_article.dump_article_as_plain_text( - os.path.join(self.article_output_dir, "storm_gen_article.txt") - ) - draft_article.dump_reference_to_file( - os.path.join(self.article_output_dir, "url_to_info.json") + callback_handler=callback_handler ) + draft_article.dump_article_as_plain_text(os.path.join(self.article_output_dir, 'storm_gen_article.txt')) + draft_article.dump_reference_to_file(os.path.join(self.article_output_dir, 'url_to_info.json')) return draft_article - def run_article_polishing_module( - self, draft_article: StormArticle, remove_duplicate: bool = False - ) -> StormArticle: + def run_article_polishing_module(self, + draft_article: StormArticle, + remove_duplicate: bool = False) -> StormArticle: polished_article = self.storm_article_polishing_module.polish_article( topic=self.topic, draft_article=draft_article, - remove_duplicate=remove_duplicate, - ) - FileIOHelper.write_str( - polished_article.to_string(), - os.path.join(self.article_output_dir, "storm_gen_article_polished.txt"), + remove_duplicate=remove_duplicate ) + FileIOHelper.write_str(polished_article.to_string(), + os.path.join(self.article_output_dir, 'storm_gen_article_polished.txt')) return polished_article def post_run(self): @@ -264,61 +224,43 @@ def post_run(self): 2. Dumping the LLM call history. """ config_log = self.lm_configs.log() - FileIOHelper.dump_json( - config_log, os.path.join(self.article_output_dir, "run_config.json") - ) + FileIOHelper.dump_json(config_log, os.path.join(self.article_output_dir, 'run_config.json')) llm_call_history = self.lm_configs.collect_and_reset_lm_history() - with open( - os.path.join(self.article_output_dir, "llm_call_history.jsonl"), "w" - ) as f: + with open(os.path.join(self.article_output_dir, 'llm_call_history.jsonl'), 'w') as f: for call in llm_call_history: - if "kwargs" in call: - call.pop( - "kwargs" - ) # All kwargs are dumped together to run_config.json. - f.write(json.dumps(call) + "\n") + if 'kwargs' in call: + call.pop('kwargs') # All kwargs are dumped together to run_config.json. + f.write(json.dumps(call) + '\n') def _load_information_table_from_local_fs(self, information_table_local_path): assert os.path.exists(information_table_local_path), makeStringRed( - f"{information_table_local_path} not exists. Please set --do-research argument to prepare the conversation_log.json for this topic." 
- ) - return StormInformationTable.from_conversation_log_file( - information_table_local_path - ) + f"{information_table_local_path} not exists. Please set --do-research argument to prepare the conversation_log.json for this topic.") + return StormInformationTable.from_conversation_log_file(information_table_local_path) def _load_outline_from_local_fs(self, topic, outline_local_path): assert os.path.exists(outline_local_path), makeStringRed( - f"{outline_local_path} not exists. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic." - ) + f"{outline_local_path} not exists. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.") return StormArticle.from_outline_file(topic=topic, file_path=outline_local_path) - def _load_draft_article_from_local_fs( - self, topic, draft_article_path, url_to_info_path - ): + def _load_draft_article_from_local_fs(self, topic, draft_article_path, url_to_info_path): assert os.path.exists(draft_article_path), makeStringRed( - f"{draft_article_path} not exists. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic." - ) + f"{draft_article_path} not exists. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic.") assert os.path.exists(url_to_info_path), makeStringRed( - f"{url_to_info_path} not exists. Please set --do-generate-article argument to prepare the url_to_info.json for this topic." - ) + f"{url_to_info_path} not exists. Please set --do-generate-article argument to prepare the url_to_info.json for this topic.") article_text = FileIOHelper.load_str(draft_article_path) references = FileIOHelper.load_json(url_to_info_path) - return StormArticle.from_string( - topic_name=topic, article_text=article_text, references=references - ) - - def run( - self, - topic: str, - ground_truth_url: str = "", - do_research: bool = True, - do_generate_outline: bool = True, - do_generate_article: bool = True, - do_polish_article: bool = True, - remove_duplicate: bool = False, - callback_handler: BaseCallbackHandler = BaseCallbackHandler(), - ): + return StormArticle.from_string(topic_name=topic, article_text=article_text, references=references) + + def run(self, + topic: str, + ground_truth_url: str = '', + do_research: bool = True, + do_generate_outline: bool = True, + do_generate_article: bool = True, + do_polish_article: bool = True, + remove_duplicate: bool = False, + callback_handler: BaseCallbackHandler = BaseCallbackHandler()): """ Run the STORM pipeline. @@ -336,74 +278,50 @@ def run( remove_duplicate: If True, remove duplicated content. callback_handler: A callback handler to handle the intermediate results. """ - assert ( - do_research - or do_generate_outline - or do_generate_article - or do_polish_article - ), makeStringRed( - "No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article" - ) + assert do_research or do_generate_outline or do_generate_article or do_polish_article, \ + makeStringRed( + "No action is specified. 
Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article") self.topic = topic - self.article_dir_name = topic.replace(" ", "_").replace("/", "_") - self.article_output_dir = os.path.join( - self.args.output_dir, self.article_dir_name - ) + self.article_dir_name = topic.replace(' ', '_').replace('/', '_') + self.article_output_dir = os.path.join(self.args.output_dir, self.article_dir_name) os.makedirs(self.article_output_dir, exist_ok=True) # research module information_table: StormInformationTable = None if do_research: - information_table = self.run_knowledge_curation_module( - ground_truth_url=ground_truth_url, callback_handler=callback_handler - ) + information_table = self.run_knowledge_curation_module(ground_truth_url=ground_truth_url, + callback_handler=callback_handler) # outline generation module outline: StormArticle = None if do_generate_outline: # load information table if it's not initialized if information_table is None: information_table = self._load_information_table_from_local_fs( - os.path.join(self.article_output_dir, "conversation_log.json") - ) - outline = self.run_outline_generation_module( - information_table=information_table, callback_handler=callback_handler - ) + os.path.join(self.article_output_dir, 'conversation_log.json')) + outline = self.run_outline_generation_module(information_table=information_table, + callback_handler=callback_handler) # article generation module draft_article: StormArticle = None if do_generate_article: if information_table is None: information_table = self._load_information_table_from_local_fs( - os.path.join(self.article_output_dir, "conversation_log.json") - ) + os.path.join(self.article_output_dir, 'conversation_log.json')) if outline is None: - outline = self._load_outline_from_local_fs( - topic=topic, - outline_local_path=os.path.join( - self.article_output_dir, "storm_gen_outline.txt" - ), - ) - draft_article = self.run_article_generation_module( - outline=outline, - information_table=information_table, - callback_handler=callback_handler, - ) + outline = self._load_outline_from_local_fs(topic=topic, + outline_local_path=os.path.join(self.article_output_dir, + 'storm_gen_outline.txt')) + draft_article = self.run_article_generation_module(outline=outline, + information_table=information_table, + callback_handler=callback_handler) # article polishing module if do_polish_article: if draft_article is None: - draft_article_path = os.path.join( - self.article_output_dir, "storm_gen_article.txt" - ) - url_to_info_path = os.path.join( - self.article_output_dir, "url_to_info.json" - ) - draft_article = self._load_draft_article_from_local_fs( - topic=topic, - draft_article_path=draft_article_path, - url_to_info_path=url_to_info_path, - ) - self.run_article_polishing_module( - draft_article=draft_article, remove_duplicate=remove_duplicate - ) + draft_article_path = os.path.join(self.article_output_dir, 'storm_gen_article.txt') + url_to_info_path = os.path.join(self.article_output_dir, 'url_to_info.json') + draft_article = self._load_draft_article_from_local_fs(topic=topic, + draft_article_path=draft_article_path, + url_to_info_path=url_to_info_path) + self.run_article_polishing_module(draft_article=draft_article, remove_duplicate=remove_duplicate) diff --git a/knowledge_storm/storm_wiki/modules/article_generation.py b/knowledge_storm/storm_wiki/modules/article_generation.py index 2e711465..a114b3ec 100644 --- a/knowledge_storm/storm_wiki/modules/article_generation.py +++ 
b/knowledge_storm/storm_wiki/modules/article_generation.py @@ -15,48 +15,35 @@ class StormArticleGenerationModule(ArticleGenerationModule): """ The interface for article generation stage. Given topic, collected information from - knowledge curation stage, generated outline from outline generation stage, + knowledge curation stage, generated outline from outline generation stage, """ - def __init__( - self, - article_gen_lm=Union[dspy.dsp.LM, dspy.dsp.HFModel], - retrieve_top_k: int = 5, - max_thread_num: int = 10, - ): + def __init__(self, + article_gen_lm=Union[dspy.dsp.LM, dspy.dsp.HFModel], + retrieve_top_k: int = 5, + max_thread_num: int = 10): super().__init__() self.retrieve_top_k = retrieve_top_k self.article_gen_lm = article_gen_lm self.max_thread_num = max_thread_num self.section_gen = ConvToSection(engine=self.article_gen_lm) - def generate_section( - self, topic, section_name, information_table, section_outline, section_query - ): + def generate_section(self, topic, section_name, information_table, section_outline, section_query): collected_info: List[StormInformation] = [] if information_table is not None: - collected_info = information_table.retrieve_information( - queries=section_query, search_top_k=self.retrieve_top_k - ) - output = self.section_gen( - topic=topic, - outline=section_outline, - section=section_name, - collected_info=collected_info, - ) - return { - "section_name": section_name, - "section_content": output.section, - "collected_info": collected_info, - } - - def generate_article( - self, - topic: str, - information_table: StormInformationTable, - article_with_outline: StormArticle, - callback_handler: BaseCallbackHandler = None, - ) -> StormArticle: + collected_info = information_table.retrieve_information(queries=section_query, + search_top_k=self.retrieve_top_k) + output = self.section_gen(topic=topic, + outline=section_outline, + section=section_name, + collected_info=collected_info) + return {"section_name": section_name, "section_content": output.section, "collected_info": collected_info} + + def generate_article(self, + topic: str, + information_table: StormInformationTable, + article_with_outline: StormArticle, + callback_handler: BaseCallbackHandler = None) -> StormArticle: """ Generate article for the topic based on the information table and article outline. @@ -76,48 +63,35 @@ def generate_article( section_output_dict_collection = [] if len(sections_to_write) == 0: - logging.error( - f"No outline for {topic}. Will directly search with the topic." - ) + logging.error(f'No outline for {topic}. Will directly search with the topic.') section_output_dict = self.generate_section( topic=topic, section_name=topic, information_table=information_table, section_outline="", - section_query=[topic], + section_query=[topic] ) section_output_dict_collection = [section_output_dict] else: - with concurrent.futures.ThreadPoolExecutor( - max_workers=self.max_thread_num - ) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor: future_to_sec_title = {} for section_title in sections_to_write: # We don't want to write a separate introduction section. - if section_title.lower().strip() == "introduction": + if section_title.lower().strip() == 'introduction': continue # We don't want to write a separate conclusion section. 
if section_title.lower().strip().startswith( - "conclusion" - ) or section_title.lower().strip().startswith("summary"): + 'conclusion') or section_title.lower().strip().startswith('summary'): continue - section_query = article_with_outline.get_outline_as_list( - root_section_name=section_title, add_hashtags=False - ) + section_query = article_with_outline.get_outline_as_list(root_section_name=section_title, + add_hashtags=False) queries_with_hashtags = article_with_outline.get_outline_as_list( - root_section_name=section_title, add_hashtags=True - ) + root_section_name=section_title, add_hashtags=True) section_outline = "\n".join(queries_with_hashtags) future_to_sec_title[ - executor.submit( - self.generate_section, - topic, - section_title, - information_table, - section_outline, - section_query, - ) + executor.submit(self.generate_section, + topic, section_title, information_table, section_outline, section_query) ] = section_title for future in as_completed(future_to_sec_title): @@ -125,11 +99,9 @@ def generate_article( article = copy.deepcopy(article_with_outline) for section_output_dict in section_output_dict_collection: - article.update_section( - parent_section_name=topic, - current_section_content=section_output_dict["section_content"], - current_section_info_list=section_output_dict["collected_info"], - ) + article.update_section(parent_section_name=topic, + current_section_content=section_output_dict["section_content"], + current_section_info_list=section_output_dict["collected_info"]) article.post_processing() return article @@ -142,24 +114,17 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.write_section = dspy.Predict(WriteSection) self.engine = engine - def forward( - self, - topic: str, - outline: str, - section: str, - collected_info: List[StormInformation], - ): - info = "" + def forward(self, topic: str, outline: str, section: str, collected_info: List[StormInformation]): + info = '' for idx, storm_info in enumerate(collected_info): - info += f"[{idx + 1}]\n" + "\n".join(storm_info.snippets) - info += "\n\n" + info += f'[{idx + 1}]\n' + '\n'.join(storm_info.snippets) + info += '\n\n' info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1500) with dspy.settings.context(lm=self.engine): section = ArticleTextProcessing.clean_up_section( - self.write_section(topic=topic, info=info, section=section).output - ) + self.write_section(topic=topic, info=info, section=section).output) return dspy.Prediction(section=section) @@ -167,9 +132,9 @@ def forward( class WriteSection(dspy.Signature): """Write a Wikipedia section based on the collected information. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. + Here is the format of your writing: + 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Use [1], [2], ..., [n] in line (for example, "The capital of the United States is Washington, D.C.[1][3]."). You DO NOT need to include a References or Sources section to list the sources at the end. 
""" info = dspy.InputField(prefix="The collected information:\n", format=str) @@ -177,5 +142,5 @@ class WriteSection(dspy.Signature): section = dspy.InputField(prefix="The section you need to write: ", format=str) output = dspy.OutputField( prefix="Write the section with proper inline citations (Start your writing with # section title. Don't include the page title or try to write other sections):\n", - format=str, + format=str ) diff --git a/knowledge_storm/storm_wiki/modules/article_polish.py b/knowledge_storm/storm_wiki/modules/article_polish.py index fb85b0f3..b70bb834 100644 --- a/knowledge_storm/storm_wiki/modules/article_polish.py +++ b/knowledge_storm/storm_wiki/modules/article_polish.py @@ -14,21 +14,21 @@ class StormArticlePolishingModule(ArticlePolishingModule): knowledge curation stage, generated outline from outline generation stage. """ - def __init__( - self, - article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - ): + def __init__(self, + article_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + article_polish_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.article_gen_lm = article_gen_lm self.article_polish_lm = article_polish_lm self.polish_page = PolishPageModule( - write_lead_engine=self.article_gen_lm, polish_engine=self.article_polish_lm + write_lead_engine=self.article_gen_lm, + polish_engine=self.article_polish_lm ) - def polish_article( - self, topic: str, draft_article: StormArticle, remove_duplicate: bool = False - ) -> StormArticle: + def polish_article(self, + topic: str, + draft_article: StormArticle, + remove_duplicate: bool = False) -> StormArticle: """ Polish article. @@ -39,14 +39,10 @@ def polish_article( """ article_text = draft_article.to_string() - polish_result = self.polish_page( - topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate - ) + polish_result = self.polish_page(topic=topic, draft_page=article_text, polish_whole_page=remove_duplicate) lead_section = f"# summary\n{polish_result.lead_section}" - polished_article = "\n\n".join([lead_section, polish_result.page]) - polished_article_dict = ArticleTextProcessing.parse_article_into_dict( - polished_article - ) + polished_article = '\n\n'.join([lead_section, polish_result.page]) + polished_article_dict = ArticleTextProcessing.parse_article_into_dict(polished_article) polished_article = copy.deepcopy(draft_article) polished_article.insert_or_create_section(article_dict=polished_article_dict) polished_article.post_processing() @@ -55,10 +51,9 @@ def polish_article( class WriteLeadSection(dspy.Signature): """Write a lead section for the given Wikipedia page with the following guidelines: - 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies. - 2. The lead section should be concise and contain no more than four well-composed paragraphs. - 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary. - """ + 1. The lead should stand on its own as a concise overview of the article's topic. It should identify the topic, establish context, explain why the topic is notable, and summarize the most important points, including any prominent controversies. + 2. 
The lead section should be concise and contain no more than four well-composed paragraphs. + 3. The lead section should be carefully sourced as appropriate. Add inline citations (e.g., "Washington, D.C., is the capital of the United States.[1][3].") where necessary.""" topic = dspy.InputField(prefix="The topic of the page: ", format=str) draft_page = dspy.InputField(prefix="The draft page:\n", format=str) @@ -73,11 +68,8 @@ class PolishPage(dspy.Signature): class PolishPageModule(dspy.Module): - def __init__( - self, - write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - ): + def __init__(self, write_lead_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + polish_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): super().__init__() self.write_lead_engine = write_lead_engine self.polish_engine = polish_engine @@ -86,9 +78,7 @@ def __init__( def forward(self, topic: str, draft_page: str, polish_whole_page: bool = True): with dspy.settings.context(lm=self.write_lead_engine): - lead_section = self.write_lead( - topic=topic, draft_page=draft_page - ).lead_section + lead_section = self.write_lead(topic=topic, draft_page=draft_page).lead_section if "The lead section:" in lead_section: lead_section = lead_section.split("The lead section:")[1].strip() if polish_whole_page: diff --git a/knowledge_storm/storm_wiki/modules/knowledge_curation.py b/knowledge_storm/storm_wiki/modules/knowledge_curation.py index bde27678..8e881c65 100644 --- a/knowledge_storm/storm_wiki/modules/knowledge_curation.py +++ b/knowledge_storm/storm_wiki/modules/knowledge_curation.py @@ -25,32 +25,20 @@ class ConvSimulator(dspy.Module): """Simulate a conversation between a Wikipedia writer with specific persona and an expert.""" - def __init__( - self, - topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], - retriever: Retriever, - max_search_queries_per_turn: int, - search_top_k: int, - max_turn: int, - ): + def __init__(self, topic_expert_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + question_asker_engine: Union[dspy.dsp.LM, dspy.dsp.HFModel], + retriever: Retriever, max_search_queries_per_turn: int, search_top_k: int, max_turn: int): super().__init__() self.wiki_writer = WikiWriter(engine=question_asker_engine) self.topic_expert = TopicExpert( engine=topic_expert_engine, max_search_queries=max_search_queries_per_turn, search_top_k=search_top_k, - retriever=retriever, + retriever=retriever ) self.max_turn = max_turn - def forward( - self, - topic: str, - persona: str, - ground_truth_url: str, - callback_handler: BaseCallbackHandler, - ): + def forward(self, topic: str, persona: str, ground_truth_url: str, callback_handler: BaseCallbackHandler): """ topic: The topic to research. persona: The persona of the Wikipedia writer. 
@@ -58,22 +46,18 @@ def forward( """ dlg_history: List[DialogueTurn] = [] for _ in range(self.max_turn): - user_utterance = self.wiki_writer( - topic=topic, persona=persona, dialogue_turns=dlg_history - ).question - if user_utterance == "": - logging.error("Simulated Wikipedia writer utterance is empty.") + user_utterance = self.wiki_writer(topic=topic, persona=persona, dialogue_turns=dlg_history).question + if user_utterance == '': + logging.error('Simulated Wikipedia writer utterance is empty.') break - if user_utterance.startswith("Thank you so much for your help!"): + if user_utterance.startswith('Thank you so much for your help!'): break - expert_output = self.topic_expert( - topic=topic, question=user_utterance, ground_truth_url=ground_truth_url - ) + expert_output = self.topic_expert(topic=topic, question=user_utterance, ground_truth_url=ground_truth_url) dlg_turn = DialogueTurn( agent_utterance=expert_output.answer, user_utterance=user_utterance, search_queries=expert_output.queries, - search_results=expert_output.searched_results, + search_results=expert_output.searched_results ) dlg_history.append(dlg_turn) callback_handler.on_dialogue_turn_end(dlg_turn=dlg_turn) @@ -92,35 +76,22 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.ask_question = dspy.ChainOfThought(AskQuestion) self.engine = engine - def forward( - self, - topic: str, - persona: str, - dialogue_turns: List[DialogueTurn], - draft_page=None, - ): + def forward(self, topic: str, persona: str, dialogue_turns: List[DialogueTurn], draft_page=None): conv = [] for turn in dialogue_turns[:-4]: - conv.append( - f"You: {turn.user_utterance}\nExpert: Omit the answer here due to space limit." - ) + conv.append(f'You: {turn.user_utterance}\nExpert: Omit the answer here due to space limit.') for turn in dialogue_turns[-4:]: conv.append( - f"You: {turn.user_utterance}\nExpert: {ArticleTextProcessing.remove_citations(turn.agent_utterance)}" - ) - conv = "\n".join(conv) - conv = conv.strip() or "N/A" + f'You: {turn.user_utterance}\nExpert: {ArticleTextProcessing.remove_citations(turn.agent_utterance)}') + conv = '\n'.join(conv) + conv = conv.strip() or 'N/A' conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 2500) with dspy.settings.context(lm=self.engine): if persona is not None and len(persona.strip()) > 0: - question = self.ask_question_with_persona( - topic=topic, persona=persona, conv=conv - ).question + question = self.ask_question_with_persona(topic=topic, persona=persona, conv=conv).question else: - question = self.ask_question( - topic=topic, persona=persona, conv=conv - ).question + question = self.ask_question(topic=topic, persona=persona, conv=conv).question return dspy.Prediction(question=question) @@ -128,11 +99,10 @@ def forward( class AskQuestion(dspy.Signature): """You are an experienced Wikipedia writer. You are chatting with an expert to get information for the topic you want to contribute. Ask good questions to get more useful information relevant to the topic. When you have no more question to ask, say "Thank you so much for your help!" to end the conversation. - Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write. - """ + Please only ask a question at a time and don't ask what you have asked before. 
Your questions should be related to the topic you want to write.""" - topic = dspy.InputField(prefix="Topic you want to write: ", format=str) - conv = dspy.InputField(prefix="Conversation history:\n", format=str) + topic = dspy.InputField(prefix='Topic you want to write: ', format=str) + conv = dspy.InputField(prefix='Conversation history:\n', format=str) question = dspy.OutputField(format=str) @@ -140,41 +110,38 @@ class AskQuestionWithPersona(dspy.Signature): """You are an experienced Wikipedia writer and want to edit a specific page. Besides your identity as a Wikipedia writer, you have specific focus when researching the topic. Now, you are chatting with an expert to get information. Ask good questions to get more useful information. When you have no more question to ask, say "Thank you so much for your help!" to end the conversation. - Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write. - """ + Please only ask a question at a time and don't ask what you have asked before. Your questions should be related to the topic you want to write.""" - topic = dspy.InputField(prefix="Topic you want to write: ", format=str) - persona = dspy.InputField( - prefix="Your persona besides being a Wikipedia writer: ", format=str - ) - conv = dspy.InputField(prefix="Conversation history:\n", format=str) + topic = dspy.InputField(prefix='Topic you want to write: ', format=str) + persona = dspy.InputField(prefix='Your persona besides being a Wikipedia writer: ', format=str) + conv = dspy.InputField(prefix='Conversation history:\n', format=str) question = dspy.OutputField(format=str) class QuestionToQuery(dspy.Signature): """You want to answer the question using Google search. What do you type in the search box? - Write the queries you will use in the following format: - - query 1 - - query 2 - ... - - query n""" - - topic = dspy.InputField(prefix="Topic you are discussing about: ", format=str) - question = dspy.InputField(prefix="Question you want to answer: ", format=str) + Write the queries you will use in the following format: + - query 1 + - query 2 + ... + - query n""" + + topic = dspy.InputField(prefix='Topic you are discussing about: ', format=str) + question = dspy.InputField(prefix='Question you want to answer: ', format=str) queries = dspy.OutputField(format=str) class AnswerQuestion(dspy.Signature): """You are an expert who can use information effectively. You are chatting with a Wikipedia writer who wants to write a Wikipedia page on topic you know. You have gathered the related information and will now use the information to form a response. - Make your response as informative as possible and make sure every sentence is supported by the gathered information. If [Gathered information] is not related to he [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.". - """ + Make your response as informative as possible and make sure every sentence is supported by the gathered information. 
If [Gathered information] is not related to the [Topic] and [Question], output "Sorry, I don't have enough information to answer the question.".\"\"\"

-    topic = dspy.InputField(prefix="Topic you are discussing about:", format=str)
-    conv = dspy.InputField(prefix="Question:\n", format=str)
-    info = dspy.InputField(prefix="Gathered information:\n", format=str)
+    topic = dspy.InputField(prefix='Topic you are discussing about:', format=str)
+    conv = dspy.InputField(prefix='Question:\n', format=str)
+    info = dspy.InputField(
+        prefix='Gathered information:\n', format=str)
     answer = dspy.OutputField(
-        prefix="Now give your response. (Try to use as many different sources as possible and add do not hallucinate.)\n",
-        format=str,
+        prefix='Now give your response. (Try to use as many different sources as possible and do not hallucinate.)\n',
+        format=str
     )
@@ -186,13 +153,8 @@ class TopicExpert(dspy.Module):
         4. Generate an answer using the retrieved information.
     """

-    def __init__(
-        self,
-        engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
-        max_search_queries: int,
-        search_top_k: int,
-        retriever: Retriever,
-    ):
+    def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel],
+                 max_search_queries: int, search_top_k: int, retriever: Retriever):
         super().__init__()
         self.generate_queries = dspy.Predict(QuestionToQuery)
         self.retriever = retriever
@@ -206,43 +168,31 @@ def forward(self, topic: str, question: str, ground_truth_url: str):
         with dspy.settings.context(lm=self.engine):
             # Identify: Break down question into queries.
             queries = self.generate_queries(topic=topic, question=question).queries
-            queries = [
-                q.replace("-", "").strip().strip('"').strip('"').strip()
-                for q in queries.split("\n")
-            ]
-            queries = queries[: self.max_search_queries]
+            queries = [q.replace('-', '').strip().strip('"').strip('"').strip() for q in queries.split('\n')]
+            queries = queries[:self.max_search_queries]
             # Search
-            searched_results: List[StormInformation] = self.retriever.retrieve(
-                list(set(queries)), exclude_urls=[ground_truth_url]
-            )
+            searched_results: List[StormInformation] = self.retriever.retrieve(list(set(queries)),
+                                                                               exclude_urls=[ground_truth_url])
             if len(searched_results) > 0:
                 # Evaluate: Simplify this part by directly using the top 1 snippet.
-                info = ""
+                info = ''
                 for n, r in enumerate(searched_results):
-                    info += "\n".join(f"[{n + 1}]: {s}" for s in r.snippets[:1])
-                    info += "\n\n"
+                    info += '\n'.join(f'[{n + 1}]: {s}' for s in r.snippets[:1])
+                    info += '\n\n'

-                info = ArticleTextProcessing.limit_word_count_preserve_newline(
-                    info, 1000
-                )
+                info = ArticleTextProcessing.limit_word_count_preserve_newline(info, 1000)

                 try:
-                    answer = self.answer_question(
-                        topic=topic, conv=question, info=info
-                    ).answer
-                    answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(
-                        answer
-                    )
+                    answer = self.answer_question(topic=topic, conv=question, info=info).answer
+                    answer = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(answer)
                 except Exception as e:
-                    logging.error(f"Error occurs when generating answer: {e}")
-                    answer = "Sorry, I cannot answer this question. Please ask another question."
+                    logging.error(f'Error occurs when generating answer: {e}')
+                    answer = 'Sorry, I cannot answer this question. Please ask another question.'
             else:
                 # When no information is found, the expert shouldn't hallucinate.
-                answer = "Sorry, I cannot find information for this question. Please ask another question."
+                answer = 'Sorry, I cannot find information for this question.
Please ask another question.' - return dspy.Prediction( - queries=queries, searched_results=searched_results, answer=answer - ) + return dspy.Prediction(queries=queries, searched_results=searched_results, answer=answer) class StormKnowledgeCurationModule(KnowledgeCurationModule): @@ -250,17 +200,15 @@ class StormKnowledgeCurationModule(KnowledgeCurationModule): The interface for knowledge curation stage. Given topic, return collected information. """ - def __init__( - self, - retriever: Retriever, - persona_generator: Optional[StormPersonaGenerator], - conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], - max_search_queries_per_turn: int, - search_top_k: int, - max_conv_turn: int, - max_thread_num: int, - ): + def __init__(self, + retriever: Retriever, + persona_generator: Optional[StormPersonaGenerator], + conv_simulator_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + question_asker_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel], + max_search_queries_per_turn: int, + search_top_k: int, + max_conv_turn: int, + max_thread_num: int): """ Store args and finish initialization. """ @@ -276,22 +224,14 @@ def __init__( retriever=retriever, max_search_queries_per_turn=max_search_queries_per_turn, search_top_k=search_top_k, - max_turn=max_conv_turn, + max_turn=max_conv_turn ) def _get_considered_personas(self, topic: str, max_num_persona) -> List[str]: - return self.persona_generator.generate_persona( - topic=topic, max_num_persona=max_num_persona - ) + return self.persona_generator.generate_persona(topic=topic, max_num_persona=max_num_persona) - def _run_conversation( - self, - conv_simulator, - topic, - ground_truth_url, - considered_personas, - callback_handler: BaseCallbackHandler, - ) -> List[Tuple[str, List[DialogueTurn]]]: + def _run_conversation(self, conv_simulator, topic, ground_truth_url, considered_personas, + callback_handler: BaseCallbackHandler) -> List[Tuple[str, List[DialogueTurn]]]: """ Executes multiple conversation simulations concurrently, each with a different persona, and collects their dialog histories. The dialog history of each conversation is cleaned @@ -320,16 +260,13 @@ def run_conv(persona): topic=topic, ground_truth_url=ground_truth_url, persona=persona, - callback_handler=callback_handler, + callback_handler=callback_handler ) max_workers = min(self.max_thread_num, len(considered_personas)) with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_persona = { - executor.submit(run_conv, persona): persona - for persona in considered_personas - } + future_to_persona = {executor.submit(run_conv, persona): persona for persona in considered_personas} if streamlit_connection: # Ensure the logging context is correct when connecting with Streamlit frontend. 
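
For context, the fan-out in `_run_conversation` above boils down to the following minimal, self-contained sketch. `run_conv` here is a hypothetical stand-in for the real call into `self.conv_simulator`; only the concurrency shape is meant to match.

# A minimal sketch of the per-persona fan-out used by _run_conversation.
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple


def run_conv(persona: str) -> str:
    # Stand-in: the real implementation runs a multi-turn simulated dialogue.
    return f"dialogue history for {persona!r}"


def run_all(personas: List[str], max_thread_num: int = 5) -> List[Tuple[str, str]]:
    conversations = []
    max_workers = min(max_thread_num, len(personas))
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_persona = {executor.submit(run_conv, p): p for p in personas}
        # as_completed yields futures in completion order, not submission order,
        # so the collected list is effectively unordered across personas.
        for future in as_completed(future_to_persona):
            conversations.append((future_to_persona[future], future.result()))
    return conversations


if __name__ == "__main__":
    print(run_all(["historian", "economist"]))
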
@@ -339,27 +276,23 @@ def run_conv(persona): for future in as_completed(future_to_persona): persona = future_to_persona[future] conv = future.result() - conversations.append( - (persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history) - ) + conversations.append((persona, ArticleTextProcessing.clean_up_citation(conv).dlg_history)) return conversations - def research( - self, - topic: str, - ground_truth_url: str, - callback_handler: BaseCallbackHandler, - max_perspective: int = 0, - disable_perspective: bool = True, - return_conversation_log=False, - ) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]: + def research(self, + topic: str, + ground_truth_url: str, + callback_handler: BaseCallbackHandler, + max_perspective: int = 0, + disable_perspective: bool = True, + return_conversation_log=False) -> Union[StormInformationTable, Tuple[StormInformationTable, Dict]]: """ Curate information and knowledge for the given topic Args: topic: topic of interest in natural language. - + Returns: collected_information: collected information in InformationTable type. """ @@ -370,25 +303,19 @@ def research( if disable_perspective: considered_personas = [""] else: - considered_personas = self._get_considered_personas( - topic=topic, max_num_persona=max_perspective - ) + considered_personas = self._get_considered_personas(topic=topic, max_num_persona=max_perspective) callback_handler.on_identify_perspective_end(perspectives=considered_personas) - # run conversation + # run conversation callback_handler.on_information_gathering_start() - conversations = self._run_conversation( - conv_simulator=self.conv_simulator, - topic=topic, - ground_truth_url=ground_truth_url, - considered_personas=considered_personas, - callback_handler=callback_handler, - ) + conversations = self._run_conversation(conv_simulator=self.conv_simulator, + topic=topic, + ground_truth_url=ground_truth_url, + considered_personas=considered_personas, + callback_handler=callback_handler) information_table = StormInformationTable(conversations) callback_handler.on_information_gathering_end() if return_conversation_log: - return information_table, StormInformationTable.construct_log_dict( - conversations - ) + return information_table, StormInformationTable.construct_log_dict(conversations) return information_table diff --git a/knowledge_storm/storm_wiki/modules/outline_generation.py b/knowledge_storm/storm_wiki/modules/outline_generation.py index a96c7978..1f45b1c2 100644 --- a/knowledge_storm/storm_wiki/modules/outline_generation.py +++ b/knowledge_storm/storm_wiki/modules/outline_generation.py @@ -14,19 +14,18 @@ class StormOutlineGenerationModule(OutlineGenerationModule): curation stage, generate outline for the article. 
""" - def __init__(self, outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): + def __init__(self, + outline_gen_lm: Union[dspy.dsp.LM, dspy.dsp.HFModel]): super().__init__() self.outline_gen_lm = outline_gen_lm self.write_outline = WriteOutline(engine=self.outline_gen_lm) - def generate_outline( - self, - topic: str, - information_table: StormInformationTable, - old_outline: Optional[StormArticle] = None, - callback_handler: BaseCallbackHandler = None, - return_draft_outline=False, - ) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]: + def generate_outline(self, + topic: str, + information_table: StormInformationTable, + old_outline: Optional[StormArticle] = None, + callback_handler: BaseCallbackHandler = None, + return_draft_outline=False) -> Union[StormArticle, Tuple[StormArticle, StormArticle]]: """ Generates an outline for an article based on the specified topic and the information gathered during the knowledge curation stage. This method can optionally return both the @@ -35,38 +34,30 @@ def generate_outline( Args: topic (str): The topic of the article. information_table (StormInformationTable): The information table containing the collected information. - old_outline (Optional[StormArticle]): An optional previous version of the article outline that can + old_outline (Optional[StormArticle]): An optional previous version of the article outline that can be used for reference or comparison. Defaults to None. - callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger - custom callbacks at various stages of the outline generation process, such as when the information + callback_handler (BaseCallbackHandler): An optional callback handler that can be used to trigger + custom callbacks at various stages of the outline generation process, such as when the information organization starts. Defaults to None. - return_draft_outline (bool): A flag indicating whether the method should return both the final article - outline and a draft version of the outline. If False, only the final article outline is returned. + return_draft_outline (bool): A flag indicating whether the method should return both the final article + outline and a draft version of the outline. If False, only the final article outline is returned. Defaults to False. Returns: - Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`, - this method returns either a single `StormArticle` object containing the final outline or a tuple of - two `StormArticle` objects, the first containing the final outline and the second containing the + Union[StormArticle, Tuple[StormArticle, StormArticle]]: Depending on the value of `return_draft_outline`, + this method returns either a single `StormArticle` object containing the final outline or a tuple of + two `StormArticle` objects, the first containing the final outline and the second containing the draft outline. 
""" if callback_handler is not None: callback_handler.on_information_organization_start() - concatenated_dialogue_turns = sum( - [conv for (_, conv) in information_table.conversations], [] - ) - result = self.write_outline( - topic=topic, - dlg_history=concatenated_dialogue_turns, - callback_handler=callback_handler, - ) - article_with_outline_only = StormArticle.from_outline_str( - topic=topic, outline_str=result.outline - ) - article_with_draft_outline_only = StormArticle.from_outline_str( - topic=topic, outline_str=result.old_outline - ) + concatenated_dialogue_turns = sum([conv for (_, conv) in information_table.conversations], []) + result = self.write_outline(topic=topic, dlg_history=concatenated_dialogue_turns, + callback_handler=callback_handler) + article_with_outline_only = StormArticle.from_outline_str(topic=topic, outline_str=result.outline) + article_with_draft_outline_only = StormArticle.from_outline_str(topic=topic, + outline_str=result.old_outline) if not return_draft_outline: return article_with_outline_only return article_with_outline_only, article_with_draft_outline_only @@ -81,44 +72,25 @@ def __init__(self, engine: Union[dspy.dsp.LM, dspy.dsp.HFModel]): self.write_page_outline = dspy.Predict(WritePageOutlineFromConv) self.engine = engine - def forward( - self, - topic: str, - dlg_history, - old_outline: Optional[str] = None, - callback_handler: BaseCallbackHandler = None, - ): + def forward(self, topic: str, dlg_history, old_outline: Optional[str] = None, + callback_handler: BaseCallbackHandler = None): trimmed_dlg_history = [] for turn in dlg_history: - if ( - "topic you" in turn.agent_utterance.lower() - or "topic you" in turn.user_utterance.lower() - ): + if 'topic you' in turn.agent_utterance.lower() or 'topic you' in turn.user_utterance.lower(): continue trimmed_dlg_history.append(turn) - conv = "\n".join( - [ - f"Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}" - for turn in trimmed_dlg_history - ] - ) + conv = '\n'.join([f'Wikipedia Writer: {turn.user_utterance}\nExpert: {turn.agent_utterance}' for turn in + trimmed_dlg_history]) conv = ArticleTextProcessing.remove_citations(conv) conv = ArticleTextProcessing.limit_word_count_preserve_newline(conv, 5000) with dspy.settings.context(lm=self.engine): if old_outline is None: - old_outline = ArticleTextProcessing.clean_up_outline( - self.draft_page_outline(topic=topic).outline - ) + old_outline = ArticleTextProcessing.clean_up_outline(self.draft_page_outline(topic=topic).outline) if callback_handler: - callback_handler.on_direct_outline_generation_end( - outline=old_outline - ) + callback_handler.on_direct_outline_generation_end(outline=old_outline) outline = ArticleTextProcessing.clean_up_outline( - self.write_page_outline( - topic=topic, old_outline=old_outline, conv=conv - ).outline - ) + self.write_page_outline(topic=topic, old_outline=old_outline, conv=conv).outline) if callback_handler: callback_handler.on_outline_refinement_end(outline=outline) @@ -127,10 +99,10 @@ def forward( class WritePageOutline(dspy.Signature): """Write an outline for a Wikipedia page. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Do not include other information. - 3. Do not include topic name itself in the outline. + Here is the format of your writing: + 1. 
Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Do not include other information. + 3. Do not include topic name itself in the outline. """ topic = dspy.InputField(prefix="The topic you want to write: ", format=str) @@ -152,10 +124,10 @@ def forward(self, topic: str): class WritePageOutlineFromConv(dspy.Signature): """Improve an outline for a Wikipedia page. You already have a draft outline that covers the general information. Now you want to improve it based on the information learned from an information-seeking conversation to make it more informative. - Here is the format of your writing: - 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. - 2. Do not include other information. - 3. Do not include topic name itself in the outline. + Here is the format of your writing: + 1. Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, "###" Title" to indicate subsubsection title, and so on. + 2. Do not include other information. + 3. Do not include topic name itself in the outline. """ topic = dspy.InputField(prefix="The topic you want to write: ", format=str) @@ -163,5 +135,5 @@ class WritePageOutlineFromConv(dspy.Signature): old_outline = dspy.OutputField(prefix="Current outline:\n", format=str) outline = dspy.OutputField( prefix='Write the Wikipedia page outline (Use "#" Title" to indicate section title, "##" Title" to indicate subsection title, ...):\n', - format=str, + format=str ) diff --git a/knowledge_storm/storm_wiki/modules/persona_generator.py b/knowledge_storm/storm_wiki/modules/persona_generator.py index c51dc0cc..5150e31b 100644 --- a/knowledge_storm/storm_wiki/modules/persona_generator.py +++ b/knowledge_storm/storm_wiki/modules/persona_generator.py @@ -11,27 +11,19 @@ def get_wiki_page_title_and_toc(url): """Get the main title and table of contents from an url of a Wikipedia page.""" response = requests.get(url) - soup = BeautifulSoup(response.content, "html.parser") + soup = BeautifulSoup(response.content, 'html.parser') # Get the main title from the first h1 tag - main_title = soup.find("h1").text.replace("[edit]", "").strip().replace("\xa0", " ") + main_title = soup.find('h1').text.replace('[edit]', '').strip().replace('\xa0', ' ') toc = "" levels = [] - excluded_sections = { - "Contents", - "See also", - "Notes", - "References", - "External links", - } + excluded_sections = {'Contents', 'See also', 'Notes', 'References', 'External links'} # Start processing from h2 to exclude the main title from TOC - for header in soup.find_all(["h2", "h3", "h4", "h5", "h6"]): - level = int( - header.name[1] - ) # Extract the numeric part of the header tag (e.g., '2' from 'h2') - section_title = header.text.replace("[edit]", "").strip().replace("\xa0", " ") + for header in soup.find_all(['h2', 'h3', "h4", "h5", "h6"]): + level = int(header.name[1]) # Extract the numeric part of the header tag (e.g., '2' from 'h2') + section_title = header.text.replace('[edit]', '').strip().replace('\xa0', ' ') if section_title in excluded_sections: continue @@ -47,9 +39,9 @@ def get_wiki_page_title_and_toc(url): class FindRelatedTopic(dspy.Signature): """I'm writing a Wikipedia page for a topic mentioned below. Please identify and recommend some Wikipedia pages on closely related subjects. 
I'm looking for examples that provide insights into interesting aspects commonly associated with this topic, or examples that help me understand the typical content and structure included in Wikipedia pages for similar topics. - Please list the urls in separate lines.""" + Please list the urls in separate lines.""" - topic = dspy.InputField(prefix="Topic of interest:", format=str) + topic = dspy.InputField(prefix='Topic of interest:', format=str) related_topics = dspy.OutputField(format=str) @@ -58,10 +50,8 @@ class GenPersona(dspy.Signature): Give your answer in the following format: 1. short summary of editor 1: description\n2. short summary of editor 2: description\n... """ - topic = dspy.InputField(prefix="Topic of interest:", format=str) - examples = dspy.InputField( - prefix="Wiki page outlines of related topics for inspiration:\n", format=str - ) + topic = dspy.InputField(prefix='Topic of interest:', format=str) + examples = dspy.InputField(prefix='Wiki page outlines of related topics for inspiration:\n', format=str) personas = dspy.OutputField(format=str) @@ -79,44 +69,38 @@ def forward(self, topic: str, draft=None): # Get section names from wiki pages of relevant topics for inspiration. related_topics = self.find_related_topic(topic=topic).related_topics urls = [] - for s in related_topics.split("\n"): - if "http" in s: - urls.append(s[s.find("http") :]) + for s in related_topics.split('\n'): + if 'http' in s: + urls.append(s[s.find('http'):]) examples = [] for url in urls: try: title, toc = get_wiki_page_title_and_toc(url) - examples.append(f"Title: {title}\nTable of Contents: {toc}") + examples.append(f'Title: {title}\nTable of Contents: {toc}') except Exception as e: - logging.error(f"Error occurs when processing {url}: {e}") + logging.error(f'Error occurs when processing {url}: {e}') continue if len(examples) == 0: - examples.append("N/A") - gen_persona_output = self.gen_persona( - topic=topic, examples="\n----------\n".join(examples) - ).personas + examples.append('N/A') + gen_persona_output = self.gen_persona(topic=topic, examples='\n----------\n'.join(examples)).personas personas = [] - for s in gen_persona_output.split("\n"): - match = re.search(r"\d+\.\s*(.*)", s) + for s in gen_persona_output.split('\n'): + match = re.search(r'\d+\.\s*(.*)', s) if match: personas.append(match.group(1)) sorted_personas = personas - return dspy.Prediction( - personas=personas, - raw_personas_output=sorted_personas, - related_topics=related_topics, - ) + return dspy.Prediction(personas=personas, raw_personas_output=sorted_personas, related_topics=related_topics) -class StormPersonaGenerator: +class StormPersonaGenerator(): """ A generator class for creating personas based on a given topic. - This class uses an underlying engine to generate personas tailored to the specified topic. - The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas, + This class uses an underlying engine to generate personas tailored to the specified topic. + The generator integrates with a `CreateWriterWithPersona` instance to create diverse personas, including a default 'Basic fact writer' persona. Attributes: @@ -149,6 +133,6 @@ def generate_persona(self, topic: str, max_num_persona: int = 3) -> List[str]: and up to `max_num_persona` additional personas generated based on the topic. """ personas = self.create_writer_with_persona(topic=topic) - default_persona = "Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic." 
+ default_persona = 'Basic fact writer: Basic fact writer focusing on broadly covering the basic facts about the topic.' considered_personas = [default_persona] + personas.personas[:max_num_persona] return considered_personas diff --git a/knowledge_storm/storm_wiki/modules/retriever.py b/knowledge_storm/storm_wiki/modules/retriever.py index 85df63ec..179ae99b 100644 --- a/knowledge_storm/storm_wiki/modules/retriever.py +++ b/knowledge_storm/storm_wiki/modules/retriever.py @@ -149,8 +149,7 @@ "WordPress.com", "Worldometer", "YouTube", - "ZDNet", -} + "ZDNet"} DEPRECATED = { "Al_Mayadeen", "ANNA_News", @@ -198,7 +197,7 @@ "VDARE", "Voltaire_Network", "WorldNetDaily", - "Zero_Hedge", + "Zero_Hedge" } BLACKLISTED = { "Advameg", @@ -219,7 +218,7 @@ "The_Points_Guy_(sponsored_content)", "Swarajya", "Veterans_Today", - "ZoomInfo", + "ZoomInfo" } @@ -238,20 +237,14 @@ class StormRetriever(Retriever): def __init__(self, rm: dspy.Retrieve, k=3): super().__init__(search_top_k=k) self._rm = rm - if hasattr(rm, "is_valid_source"): + if hasattr(rm, 'is_valid_source'): rm.is_valid_source = is_valid_wikipedia_source - def retrieve( - self, query: Union[str, List[str]], exclude_urls: List[str] = [] - ) -> List[Information]: - retrieved_data_list = self._rm( - query_or_queries=query, exclude_urls=exclude_urls - ) + def retrieve(self, query: Union[str, List[str]], exclude_urls: List[str] = []) -> List[Information]: + retrieved_data_list = self._rm(query_or_queries=query, exclude_urls=exclude_urls) for data in retrieved_data_list: - for i in range(len(data["snippets"])): + for i in range(len(data['snippets'])): # STORM generate the article with citations. We do not consider multi-hop citations. # Remove citations in the source to avoid confusion. - data["snippets"][i] = ArticleTextProcessing.remove_citations( - data["snippets"][i] - ) + data['snippets'][i] = ArticleTextProcessing.remove_citations(data['snippets'][i]) return [StormInformation.from_dict(data) for data in retrieved_data_list] diff --git a/knowledge_storm/storm_wiki/modules/storm_dataclass.py b/knowledge_storm/storm_wiki/modules/storm_dataclass.py index 43826ecc..4f54ec46 100644 --- a/knowledge_storm/storm_wiki/modules/storm_dataclass.py +++ b/knowledge_storm/storm_wiki/modules/storm_dataclass.py @@ -51,29 +51,22 @@ def from_dict(cls, info_dict): Returns: StormInformation: An instance of StormInformation. 
""" - return cls( - info_dict["url"], - info_dict["description"], - info_dict["snippets"], - info_dict["title"], - ) + return cls(info_dict['url'], info_dict['description'], info_dict['snippets'], info_dict['title']) def to_dict(self): - return { - "url": self.uuid, - "description": self.description, - "snippets": self.snippets, - "title": self.title, - } + return {"url": self.uuid, + "description": self.description, + "snippets": self.snippets, + "title": self.title} class DialogueTurn: def __init__( - self, - agent_utterance: str = None, - user_utterance: str = None, - search_queries: Optional[List[str]] = None, - search_results: Optional[List[Union[StormInformation, Dict]]] = None, + self, + agent_utterance: str = None, + user_utterance: str = None, + search_queries: Optional[List[str]] = None, + search_results: Optional[List[Union[StormInformation, Dict]]] = None ): self.agent_utterance = agent_utterance self.user_utterance = user_utterance @@ -83,9 +76,7 @@ def __init__( if self.search_results: for idx in range(len(self.search_results)): if type(self.search_results[idx]) == dict: - self.search_results[idx] = StormInformation.from_dict( - self.search_results[idx] - ) + self.search_results[idx] = StormInformation.from_dict(self.search_results[idx]) def log(self): """ @@ -94,10 +85,10 @@ def log(self): return OrderedDict( { - "agent_utterance": self.agent_utterance, - "user_utterance": self.user_utterance, - "search_queries": self.search_queries, - "search_results": [data.to_dict() for data in self.search_results], + 'agent_utterance': self.agent_utterance, + 'user_utterance': self.user_utterance, + 'search_queries': self.search_queries, + 'search_results': [data.to_dict() for data in self.search_results], } ) @@ -107,7 +98,7 @@ class StormInformationTable(InformationTable): The InformationTable class serves as data class to store the information collected during KnowledgeCuration stage. - Create subclass to incorporate more information as needed. For example, + Create subclass to incorporate more information as needed. For example, in STORM paper https://arxiv.org/pdf/2402.14207.pdf, additional information would be perspective guided dialogue history. 
""" @@ -115,17 +106,13 @@ class StormInformationTable(InformationTable): def __init__(self, conversations=List[Tuple[str, List[DialogueTurn]]]): super().__init__() self.conversations = conversations - self.url_to_info: Dict[str, StormInformation] = ( - StormInformationTable.construct_url_to_info(self.conversations) - ) + self.url_to_info: Dict[str, StormInformation] = StormInformationTable.construct_url_to_info(self.conversations) @staticmethod - def construct_url_to_info( - conversations: List[Tuple[str, List[DialogueTurn]]] - ) -> Dict[str, StormInformation]: + def construct_url_to_info(conversations: List[Tuple[str, List[DialogueTurn]]]) -> Dict[str, StormInformation]: url_to_info = {} - for persona, conv in conversations: + for (persona, conv) in conversations: for turn in conv: for storm_info in turn.search_results: if storm_info.url in url_to_info: @@ -137,13 +124,14 @@ def construct_url_to_info( return url_to_info @staticmethod - def construct_log_dict( - conversations: List[Tuple[str, List[DialogueTurn]]] - ) -> List[Dict[str, Union[str, Any]]]: + def construct_log_dict(conversations: List[Tuple[str, List[DialogueTurn]]]) -> List[Dict[str, Union[str, Any]]]: conversation_log = [] - for persona, conv in conversations: + for (persona, conv) in conversations: conversation_log.append( - {"perspective": persona, "dlg_turns": [turn.log() for turn in conv]} + { + 'perspective': persona, + 'dlg_turns': [turn.log() for turn in conv] + } ) return conversation_log @@ -158,26 +146,22 @@ def from_conversation_log_file(cls, path): conversation_log_data = FileIOHelper.load_json(path) conversations = [] for item in conversation_log_data: - dialogue_turns = [DialogueTurn(**turn) for turn in item["dlg_turns"]] - persona = item["perspective"] + dialogue_turns = [DialogueTurn(**turn) for turn in item['dlg_turns']] + persona = item['perspective'] conversations.append((persona, dialogue_turns)) return cls(conversations) def prepare_table_for_retrieval(self): - self.encoder = SentenceTransformer("paraphrase-MiniLM-L6-v2") + self.encoder = SentenceTransformer('paraphrase-MiniLM-L6-v2') self.collected_urls = [] self.collected_snippets = [] for url, information in self.url_to_info.items(): for snippet in information.snippets: self.collected_urls.append(url) self.collected_snippets.append(snippet) - self.encoded_snippets = self.encoder.encode( - self.collected_snippets, show_progress_bar=False - ) + self.encoded_snippets = self.encoder.encode(self.collected_snippets, show_progress_bar=False) - def retrieve_information( - self, queries: Union[List[str], str], search_top_k - ) -> List[StormInformation]: + def retrieve_information(self, queries: Union[List[str], str], search_top_k) -> List[StormInformation]: selected_urls = [] selected_snippets = [] if type(queries) is str: @@ -207,13 +191,14 @@ def retrieve_information( class StormArticle(Article): def __init__(self, topic_name): super().__init__(topic_name=topic_name) - self.reference = {"url_to_unified_index": {}, "url_to_info": {}} + self.reference = { + "url_to_unified_index": {}, + "url_to_info": {} + } - def find_section( - self, node: ArticleSectionNode, name: str - ) -> Optional[ArticleSectionNode]: + def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]: """ - Return the node of the section given the section name. + Return the node of the section given the section name. Args: node: the node as the root to find. 
@@ -230,18 +215,17 @@ def find_section( return result return None - def _merge_new_info_to_references( - self, new_info_list: List[StormInformation], index_to_keep=None - ) -> Dict[int, int]: + def _merge_new_info_to_references(self, new_info_list: List[StormInformation], index_to_keep=None) -> Dict[ + int, int]: """ Merges new storm information into existing references and updates the citation index mapping. Args: - new_info_list (List[StormInformation]): A list of dictionaries representing new storm information. + new_info_list (List[StormInformation]): A list of dictionaries representing new storm information. index_to_keep (List[int]): A list of index of the new_info_list to keep. If none, keep all. Returns: - Dict[int, int]: A dictionary mapping the index of each storm information piece in the input list + Dict[int, int]: A dictionary mapping the index of each storm information piece in the input list to its unified citation index in the references. """ citation_idx_mapping = {} @@ -250,32 +234,20 @@ def _merge_new_info_to_references( continue url = storm_info.url if url not in self.reference["url_to_unified_index"]: - self.reference["url_to_unified_index"][url] = ( - len(self.reference["url_to_unified_index"]) + 1 - ) # The citation index starts from 1. + self.reference["url_to_unified_index"][url] = len( + self.reference["url_to_unified_index"]) + 1 # The citation index starts from 1. self.reference["url_to_info"][url] = storm_info else: existing_snippets = self.reference["url_to_info"][url].snippets existing_snippets.extend(storm_info.snippets) - self.reference["url_to_info"][url].snippets = list( - set(existing_snippets) - ) + self.reference["url_to_info"][url].snippets = list(set(existing_snippets)) citation_idx_mapping[idx + 1] = self.reference["url_to_unified_index"][ - url - ] # The citation index starts from 1. + url] # The citation index starts from 1. 
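+        # Worked example (illustrative): if the three kept items resolve to
+        # unified indices 4, 2 and 4 via url_to_unified_index (one URL already
+        # known, one new), the returned mapping is {1: 4, 2: 2, 3: 4};
+        # update_citation_index later rewrites [1][2][3] in the text to [4][2][4].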
return citation_idx_mapping - def insert_or_create_section( - self, - article_dict: Dict[str, Dict], - parent_section_name: str = None, - trim_children=False, - ): - parent_node = ( - self.root - if parent_section_name is None - else self.find_section(self.root, parent_section_name) - ) + def insert_or_create_section(self, article_dict: Dict[str, Dict], parent_section_name: str = None, + trim_children=False): + parent_node = self.root if parent_section_name is None else self.find_section(self.root, parent_section_name) if trim_children: section_names = set(article_dict.keys()) @@ -286,83 +258,56 @@ def insert_or_create_section( for section_name, content_dict in article_dict.items(): current_section_node = self.find_section(parent_node, section_name) if current_section_node is None: - current_section_node = ArticleSectionNode( - section_name=section_name, content=content_dict["content"].strip() - ) - insert_to_front = ( - parent_node.section_name == self.root.section_name - and current_section_node.section_name == "summary" - ) - parent_node.add_child( - current_section_node, insert_to_front=insert_to_front - ) + current_section_node = ArticleSectionNode(section_name=section_name, + content=content_dict["content"].strip()) + insert_to_front = parent_node.section_name == self.root.section_name and current_section_node.section_name == "summary" + parent_node.add_child(current_section_node, insert_to_front=insert_to_front) else: current_section_node.content = content_dict["content"].strip() - self.insert_or_create_section( - article_dict=content_dict["subsections"], - parent_section_name=section_name, - trim_children=True, - ) + self.insert_or_create_section(article_dict=content_dict["subsections"], parent_section_name=section_name, + trim_children=True) - def update_section( - self, - current_section_content: str, - current_section_info_list: List[StormInformation], - parent_section_name: Optional[str] = None, - ) -> Optional[ArticleSectionNode]: + def update_section(self, + current_section_content: str, + current_section_info_list: List[StormInformation], + parent_section_name: Optional[str] = None) -> Optional[ArticleSectionNode]: """ - Add new section to the article. + Add new section to the article. Args: current_section_name: new section heading name in string format. parent_section_name: under which parent section to add the new one. Default to root. - current_section_content: optional section content. - + current_section_content: optional section content. + Returns: the ArticleSectionNode for current section if successfully created / updated. Otherwise none. 
""" if current_section_info_list is not None: - references = set( - [int(x) for x in re.findall(r"\[(\d+)\]", current_section_content)] - ) + references = set([int(x) for x in re.findall(r'\[(\d+)\]', current_section_content)]) # for any reference number greater than max number of references, delete the reference if len(references) > 0: max_ref_num = max(references) if max_ref_num > len(current_section_info_list): for i in range(len(current_section_info_list), max_ref_num + 1): - current_section_content = current_section_content.replace( - f"[{i}]", "" - ) + current_section_content = current_section_content.replace(f'[{i}]', '') if i in references: references.remove(i) # for any reference that is not used, trim it from current_section_info_list index_to_keep = [i - 1 for i in references] - citation_mapping = self._merge_new_info_to_references( - current_section_info_list, index_to_keep - ) - current_section_content = ArticleTextProcessing.update_citation_index( - current_section_content, citation_mapping - ) + citation_mapping = self._merge_new_info_to_references(current_section_info_list, index_to_keep) + current_section_content = ArticleTextProcessing.update_citation_index(current_section_content, + citation_mapping) if parent_section_name is None: parent_section_name = self.root.section_name - article_dict = ArticleTextProcessing.parse_article_into_dict( - current_section_content - ) - self.insert_or_create_section( - article_dict=article_dict, - parent_section_name=parent_section_name, - trim_children=False, - ) + article_dict = ArticleTextProcessing.parse_article_into_dict(current_section_content) + self.insert_or_create_section(article_dict=article_dict, parent_section_name=parent_section_name, + trim_children=False) - def get_outline_as_list( - self, - root_section_name: Optional[str] = None, - add_hashtags: bool = False, - include_root: bool = True, - ) -> List[str]: + def get_outline_as_list(self, root_section_name: Optional[str] = None, add_hashtags: bool = False, + include_root: bool = True) -> List[str]: """ Get outline of the article as a list. @@ -375,7 +320,7 @@ def get_outline_as_list( ###section1.2 ##section2 article.get_outline_as_list("section1") returns [section1, section1.1, section1.2, section2] - + Returns: list of section and subsection names. """ @@ -389,14 +334,8 @@ def get_outline_as_list( result = [] def preorder_traverse(node, level): - prefix = ( - "#" * level if add_hashtags else "" - ) # Adjust level if excluding root - result.append( - f"{prefix} {node.section_name}".strip() - if add_hashtags - else node.section_name - ) + prefix = "#" * level if add_hashtags else "" # Adjust level if excluding root + result.append(f"{prefix} {node.section_name}".strip() if add_hashtags else node.section_name) for child in node.children: preorder_traverse(child, level + 1) @@ -411,7 +350,7 @@ def preorder_traverse(node, level): def to_string(self) -> str: """ Get outline of the article as a list. - + Returns: list of section and subsection names. 
""" @@ -437,9 +376,7 @@ def reorder_reference_index(self): def pre_order_find_index(node): if node is not None: if node.content is not None and node.content: - ref_indices.extend( - ArticleTextProcessing.parse_citation_indices(node.content) - ) + ref_indices.extend(ArticleTextProcessing.parse_citation_indices(node.content)) for child in node.children: pre_order_find_index(child) @@ -454,9 +391,7 @@ def pre_order_find_index(node): def pre_order_update_index(node): if node is not None: if node.content is not None and node.content: - node.content = ArticleTextProcessing.update_citation_index( - node.content, ref_index_mapping - ) + node.content = ArticleTextProcessing.update_citation_index(node.content, ref_index_mapping) for child in node.children: pre_order_update_index(child) @@ -507,18 +442,18 @@ def from_outline_str(cls, topic: str, outline_str: str): instance = cls(topic) if lines: - a = lines[0].startswith("#") and lines[0].replace("#", "").strip().lower() + a = lines[0].startswith('#') and lines[0].replace('#', '').strip().lower() b = topic.lower().replace("_", " ") - adjust_level = lines[0].startswith("#") and lines[0].replace( - "#", "" - ).strip().lower() == topic.lower().replace("_", " ") + adjust_level = lines[0].startswith('#') and lines[0].replace('#', + '').strip().lower() == topic.lower().replace( + "_", " ") if adjust_level: lines = lines[1:] node_stack = [(0, instance.root)] # Stack to keep track of (level, node) for line in lines: - level = line.count("#") - adjust_level - section_name = line.replace("#", "").strip() + level = line.count('#') - adjust_level + section_name = line.replace('#', '').strip() if section_name == topic: continue @@ -552,9 +487,7 @@ def from_string(cls, topic_name: str, article_text: str, references: dict): article = cls(topic_name=topic_name) article.insert_or_create_section(article_dict=article_dict) for url in list(references["url_to_info"]): - references["url_to_info"][url] = StormInformation.from_dict( - references["url_to_info"][url] - ) + references["url_to_info"][url] = StormInformation.from_dict(references["url_to_info"][url]) article.reference = references return article diff --git a/knowledge_storm/utils.py b/knowledge_storm/utils.py index d07d067c..5cf6f457 100644 --- a/knowledge_storm/utils.py +++ b/knowledge_storm/utils.py @@ -17,7 +17,7 @@ def load_api_key(toml_file_path): try: - with open(toml_file_path, "r") as file: + with open(toml_file_path, 'r') as file: data = toml.load(file) except FileNotFoundError: print(f"File not found: {toml_file_path}", file=sys.stderr) @@ -53,19 +53,19 @@ def limit_word_count_preserve_newline(input_string, max_word_count): """ word_count = 0 - limited_string = "" + limited_string = '' - for word in input_string.split("\n"): + for word in input_string.split('\n'): line_words = word.split() for lw in line_words: if word_count < max_word_count: - limited_string += lw + " " + limited_string += lw + ' ' word_count += 1 else: break if word_count >= max_word_count: break - limited_string = limited_string.strip() + "\n" + limited_string = limited_string.strip() + '\n' return limited_string.strip() @@ -83,7 +83,7 @@ def remove_citations(s): str: The string with all citation patterns removed. """ - return re.sub(r"\[\d+(?:,\s*\d+)*\]", "", s) + return re.sub(r'\[\d+(?:,\s*\d+)*\]', '', s) @staticmethod def parse_citation_indices(s): @@ -96,7 +96,7 @@ def parse_citation_indices(s): Returns: List[int]: A list of unique citation indexes extracted from the content, in the order they appear. 
""" - matches = re.findall(r"\[\d+\]", s) + matches = re.findall(r'\[\d+\]', s) return [int(index[1:-1]) for index in matches] @staticmethod @@ -117,21 +117,19 @@ def remove_uncompleted_sentences_with_citations(text): # Convert citations like [1, 2, 3] to [1][2][3]. def replace_with_individual_brackets(match): - numbers = match.group(1).split(", ") - return " ".join(f"[{n}]" for n in numbers) + numbers = match.group(1).split(', ') + return ' '.join(f'[{n}]' for n in numbers) # Deduplicate and sort individual groups of citations. def deduplicate_group(match): citations = match.group(0) - unique_citations = list(set(re.findall(r"\[\d+\]", citations))) - sorted_citations = sorted( - unique_citations, key=lambda x: int(x.strip("[]")) - ) + unique_citations = list(set(re.findall(r'\[\d+\]', citations))) + sorted_citations = sorted(unique_citations, key=lambda x: int(x.strip('[]'))) # Return the sorted unique citations as a string - return "".join(sorted_citations) + return ''.join(sorted_citations) - text = re.sub(r"\[([0-9, ]+)\]", replace_with_individual_brackets, text) - text = re.sub(r"(\[\d+\])+", deduplicate_group, text) + text = re.sub(r'\[([0-9, ]+)\]', replace_with_individual_brackets, text) + text = re.sub(r'(\[\d+\])+', deduplicate_group, text) # Deprecated: Remove sentence without proper ending punctuation and citations. # Split the text into sentences (including citations). @@ -152,38 +150,29 @@ def deduplicate_group(match): # combined_sentences += ' '.join(trailing_citations) # Regex pattern to match sentence endings, including optional citation markers. - eos_pattern = r"([.!?])\s*(\[\d+\])?\s*" + eos_pattern = r'([.!?])\s*(\[\d+\])?\s*' matches = list(re.finditer(eos_pattern, text)) if matches: last_match = matches[-1] - text = text[: last_match.end()].strip() + text = text[:last_match.end()].strip() return text @staticmethod def clean_up_citation(conv): for turn in conv.dlg_history: - turn.agent_utterance = turn.agent_utterance[ - : turn.agent_utterance.find("References:") - ] - turn.agent_utterance = turn.agent_utterance[ - : turn.agent_utterance.find("Sources:") - ] - turn.agent_utterance = turn.agent_utterance.replace("Answer:", "").strip() + turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('References:')] + turn.agent_utterance = turn.agent_utterance[:turn.agent_utterance.find('Sources:')] + turn.agent_utterance = turn.agent_utterance.replace('Answer:', '').strip() try: - max_ref_num = max( - [int(x) for x in re.findall(r"\[(\d+)\]", turn.agent_utterance)] - ) + max_ref_num = max([int(x) for x in re.findall(r'\[(\d+)\]', turn.agent_utterance)]) except Exception as e: max_ref_num = 0 if max_ref_num > len(turn.search_results): for i in range(len(turn.search_results), max_ref_num + 1): - turn.agent_utterance = turn.agent_utterance.replace(f"[{i}]", "") - turn.agent_utterance = ( - ArticleTextProcessing.remove_uncompleted_sentences_with_citations( - turn.agent_utterance - ) - ) + turn.agent_utterance = turn.agent_utterance.replace(f'[{i}]', '') + turn.agent_utterance = ArticleTextProcessing.remove_uncompleted_sentences_with_citations( + turn.agent_utterance) return conv @@ -192,46 +181,36 @@ def clean_up_outline(outline, topic=""): output_lines = [] current_level = 0 # To track the current section level - for line in outline.split("\n"): + for line in outline.split('\n'): stripped_line = line.strip() if topic != "" and f"# {topic.lower()}" in stripped_line.lower(): output_lines = [] # Check if the line is a section header - if 
stripped_line.startswith("#"): - current_level = stripped_line.count("#") + if stripped_line.startswith('#'): + current_level = stripped_line.count('#') output_lines.append(stripped_line) # Check if the line is a bullet point - elif stripped_line.startswith("-"): - subsection_header = ( - "#" * (current_level + 1) + " " + stripped_line[1:].strip() - ) + elif stripped_line.startswith('-'): + subsection_header = '#' * (current_level + 1) + ' ' + stripped_line[1:].strip() output_lines.append(subsection_header) - outline = "\n".join(output_lines) + outline = '\n'.join(output_lines) # Remove references. - outline = re.sub(r"#[#]? See also.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? See Also.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Notes.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? References.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub( - r"#[#]? External links.*?(?=##|$)", "", outline, flags=re.DOTALL - ) - outline = re.sub( - r"#[#]? External Links.*?(?=##|$)", "", outline, flags=re.DOTALL - ) - outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub( - r"#[#]? Further reading*?(?=##|$)", "", outline, flags=re.DOTALL - ) - outline = re.sub( - r"#[#]? Further Reading*?(?=##|$)", "", outline, flags=re.DOTALL - ) - outline = re.sub(r"#[#]? Summary.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", "", outline, flags=re.DOTALL) - outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", "", outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? See also.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? See Also.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Notes.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? References.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? External links.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? External Links.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Bibliography.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Further reading*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Further Reading*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Summary.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Appendices.*?(?=##|$)", '', outline, flags=re.DOTALL) + outline = re.sub(r"#[#]? Appendix.*?(?=##|$)", '', outline, flags=re.DOTALL) return outline @@ -242,40 +221,34 @@ def clean_up_section(text): 2. Deduplicate individual groups of citations. 3. 
Remove unnecessary summary.""" - paragraphs = text.split("\n") + paragraphs = text.split('\n') output_paragraphs = [] summary_sec_flag = False for p in paragraphs: p = p.strip() if len(p) == 0: continue - if not p.startswith("#"): + if not p.startswith('#'): p = ArticleTextProcessing.remove_uncompleted_sentences_with_citations(p) if summary_sec_flag: - if p.startswith("#"): + if p.startswith('#'): summary_sec_flag = False else: continue - if ( - p.startswith("Overall") - or p.startswith("In summary") - or p.startswith("In conclusion") - ): + if p.startswith('Overall') or p.startswith('In summary') or p.startswith('In conclusion'): continue - if "# Summary" in p or "# Conclusion" in p: + if "# Summary" in p or '# Conclusion' in p: summary_sec_flag = True continue output_paragraphs.append(p) - return "\n\n".join(output_paragraphs) # Join with '\n\n' for markdown format. + return '\n\n'.join(output_paragraphs) # Join with '\n\n' for markdown format. @staticmethod def update_citation_index(s, citation_map): """Update citation index in the string based on the citation map.""" for original_citation in citation_map: - s = s.replace( - f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__" - ) + s = s.replace(f"[{original_citation}]", f"__PLACEHOLDER_{original_citation}__") for original_citation, unify_citation in citation_map.items(): s = s.replace(f"__PLACEHOLDER_{original_citation}__", f"[{unify_citation}]") @@ -302,34 +275,34 @@ def parse_article_into_dict(input_string): A dictionary representing contains the section title as the key, and another dictionary as the value, which includes the 'content' and 'subsections' keys as described above. """ - lines = input_string.split("\n") + lines = input_string.split('\n') lines = [line for line in lines if line.strip()] - root = {"content": "", "subsections": {}} + root = {'content': '', 'subsections': {}} current_path = [(root, -1)] # (current_dict, level) for line in lines: - if line.startswith("#"): - level = line.count("#") - title = line.strip("# ").strip() - new_section = {"content": "", "subsections": {}} + if line.startswith('#'): + level = line.count('#') + title = line.strip('# ').strip() + new_section = {'content': '', 'subsections': {}} # Pop from stack until find the parent level while current_path and current_path[-1][1] >= level: current_path.pop() # Append new section to the nearest upper level's subsections - current_path[-1][0]["subsections"][title] = new_section + current_path[-1][0]['subsections'][title] = new_section current_path.append((new_section, level)) else: - current_path[-1][0]["content"] += line + "\n" + current_path[-1][0]['content'] += line + '\n' - return root["subsections"] + return root['subsections'] class FileIOHelper: @staticmethod def dump_json(obj, file_name, encoding="utf-8"): - with open(file_name, "w", encoding=encoding) as fw: + with open(file_name, 'w', encoding=encoding) as fw: json.dump(obj, fw, default=FileIOHelper.handle_non_serializable) @staticmethod @@ -338,27 +311,27 @@ def handle_non_serializable(obj): @staticmethod def load_json(file_name, encoding="utf-8"): - with open(file_name, "r", encoding=encoding) as fr: + with open(file_name, 'r', encoding=encoding) as fr: return json.load(fr) @staticmethod def write_str(s, path): - with open(path, "w") as f: + with open(path, 'w') as f: f.write(s) @staticmethod def load_str(path): - with open(path, "r") as f: - return "\n".join(f.readlines()) + with open(path, 'r') as f: + return '\n'.join(f.readlines()) @staticmethod def dump_pickle(obj, path): 
- with open(path, "wb") as f: + with open(path, 'wb') as f: pickle.dump(obj, f) @staticmethod def load_pickle(path): - with open(path, "rb") as f: + with open(path, 'rb') as f: return pickle.load(f) @@ -368,12 +341,7 @@ class WebPageHelper: Acknowledgement: Part of the code is adapted from https://github.com/stanford-oval/WikiChat project. """ - def __init__( - self, - min_char_count: int = 150, - snippet_chunk_size: int = 1000, - max_thread_num: int = 10, - ): + def __init__(self, min_char_count: int = 150, snippet_chunk_size: int = 1000, max_thread_num: int = 10): """ Args: min_char_count: Minimum character count for the article to be considered valid. @@ -414,9 +382,7 @@ def download_webpage(self, url: str): return None def urls_to_articles(self, urls: List[str]) -> Dict: - with concurrent.futures.ThreadPoolExecutor( - max_workers=self.max_thread_num - ) as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor: htmls = list(executor.map(self.download_webpage, urls)) articles = {}
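
The truncated `urls_to_articles` above fans page downloads out over a thread pool. As a rough, self-contained sketch of that pattern (`download_webpage` and `urls_to_raw_pages` are illustrative stand-ins; the real helper additionally extracts article text and enforces `min_char_count`):

# A sketch of the threaded download pattern used by WebPageHelper.
import concurrent.futures
from typing import Dict, List, Optional

import requests


def download_webpage(url: str) -> Optional[bytes]:
    # Return raw page bytes, or None so one bad URL cannot fail the batch.
    try:
        resp = requests.get(url, timeout=4)
        resp.raise_for_status()
        return resp.content
    except requests.RequestException:
        return None


def urls_to_raw_pages(urls: List[str], max_thread_num: int = 10) -> Dict[str, bytes]:
    # executor.map preserves input order, so zip(urls, htmls) stays aligned.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_thread_num) as executor:
        htmls = list(executor.map(download_webpage, urls))
    return {u: h for u, h in zip(urls, htmls) if h is not None}

Returning None on failure keeps the map aligned with the input URLs, which is the same design choice the diff shows: failed downloads are dropped after the pool drains rather than aborting the whole batch.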