Skip to content

Commit

Permalink
Reverted back everything except requirements.txt, the new example fil…
Browse files Browse the repository at this point in the history
…e, and rm.py
  • Loading branch information
zenith110 committed Jul 28, 2024
1 parent 0b5a563 commit 85a1e6e
Show file tree
Hide file tree
Showing 12 changed files with 589 additions and 1,024 deletions.
2 changes: 1 addition & 1 deletion knowledge_storm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .storm_wiki.engine import (
STORMWikiLMConfigs,
STORMWikiRunnerArguments,
STORMWikiRunner,
STORMWikiRunner
)
69 changes: 22 additions & 47 deletions knowledge_storm/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from collections import OrderedDict
from typing import Dict, List, Optional, Union

logging.basicConfig(
level=logging.INFO, format="%(name)s : %(levelname)-8s : %(message)s"
)
logging.basicConfig(level=logging.INFO, format='%(name)s : %(levelname)-8s : %(message)s')
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -72,9 +70,7 @@ class Article(ABC):
def __init__(self, topic_name):
self.root = ArticleSectionNode(topic_name)

def find_section(
self, node: ArticleSectionNode, name: str
) -> Optional[ArticleSectionNode]:
def find_section(self, node: ArticleSectionNode, name: str) -> Optional[ArticleSectionNode]:
"""
Return the node of the section given the section name.
Expand Down Expand Up @@ -156,9 +152,7 @@ def prune_empty_nodes(self, node=None):
if node is None:
node = self.root

node.children[:] = [
child for child in node.children if self.prune_empty_nodes(child)
]
node.children[:] = [child for child in node.children if self.prune_empty_nodes(child)]

if (node.content is None or node.content == "") and not node.children:
return None
Expand All @@ -184,9 +178,7 @@ def update_search_top_k(self, k):
def collect_and_reset_rm_usage(self):
combined_usage = []
for attr_name in self.__dict__:
if "_rm" in attr_name and hasattr(
getattr(self, attr_name), "get_usage_and_reset"
):
if '_rm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

name_to_usage = {}
Expand Down Expand Up @@ -248,9 +240,7 @@ class OutlineGenerationModule(ABC):
"""

@abstractmethod
def generate_outline(
self, topic: str, information_table: InformationTable, **kwargs
) -> Article:
def generate_outline(self, topic: str, information_table: InformationTable, **kwargs) -> Article:
"""
Generate outline for the article. Required arguments include:
topic: the topic of interest
Expand All @@ -273,13 +263,11 @@ class ArticleGenerationModule(ABC):
"""

@abstractmethod
def generate_article(
self,
topic: str,
information_table: InformationTable,
article_with_outline: Article,
**kwargs,
) -> Article:
def generate_article(self,
topic: str,
information_table: InformationTable,
article_with_outline: Article,
**kwargs) -> Article:
"""
Generate article. Required arguments include:
topic: the topic of interest
Expand Down Expand Up @@ -324,23 +312,22 @@ def wrapper(self, *args, **kwargs):
class LMConfigs(ABC):
"""Abstract base class for language model configurations of the knowledge curation engine.
The language model used for each part should be declared with a suffix '_lm' in the attribute name.
"""
The language model used for each part should be declared with a suffix '_lm' in the attribute name."""

def __init__(self):
pass

def init_check(self):
for attr_name in self.__dict__:
if "_lm" in attr_name and getattr(self, attr_name) is None:
if '_lm' in attr_name and getattr(self, attr_name) is None:
logging.warning(
f"Language model for {attr_name} is not initialized. Please call set_{attr_name}()"
)

def collect_and_reset_lm_history(self):
history = []
for attr_name in self.__dict__:
if "_lm" in attr_name and hasattr(getattr(self, attr_name), "history"):
if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'history'):
history.extend(getattr(self, attr_name).history)
getattr(self, attr_name).history = []

Expand All @@ -349,9 +336,7 @@ def collect_and_reset_lm_history(self):
def collect_and_reset_lm_usage(self):
combined_usage = []
for attr_name in self.__dict__:
if "_lm" in attr_name and hasattr(
getattr(self, attr_name), "get_usage_and_reset"
):
if '_lm' in attr_name and hasattr(getattr(self, attr_name), 'get_usage_and_reset'):
combined_usage.append(getattr(self, attr_name).get_usage_and_reset())

model_name_to_usage = {}
Expand All @@ -360,22 +345,17 @@ def collect_and_reset_lm_usage(self):
if model_name not in model_name_to_usage:
model_name_to_usage[model_name] = tokens
else:
model_name_to_usage[model_name]["prompt_tokens"] += tokens[
"prompt_tokens"
]
model_name_to_usage[model_name]["completion_tokens"] += tokens[
"completion_tokens"
]
model_name_to_usage[model_name]['prompt_tokens'] += tokens['prompt_tokens']
model_name_to_usage[model_name]['completion_tokens'] += tokens['completion_tokens']

return model_name_to_usage

def log(self):

return OrderedDict(
{
attr_name: getattr(self, attr_name).kwargs
for attr_name in self.__dict__
if "_lm" in attr_name and hasattr(getattr(self, attr_name), "kwargs")
attr_name: getattr(self, attr_name).kwargs for attr_name in self.__dict__ if
'_lm' in attr_name and hasattr(getattr(self, attr_name), 'kwargs')
}
)

Expand All @@ -399,21 +379,16 @@ def wrapper(*args, **kwargs):
self.time[func.__name__] = execution_time
logger.info(f"{func.__name__} executed in {execution_time:.4f} seconds")
self.lm_cost[func.__name__] = self.lm_configs.collect_and_reset_lm_usage()
if hasattr(self, "retriever"):
self.rm_cost[func.__name__] = (
self.retriever.collect_and_reset_rm_usage()
)
if hasattr(self, 'retriever'):
self.rm_cost[func.__name__] = self.retriever.collect_and_reset_rm_usage()
return result

return wrapper

def apply_decorators(self):
"""Apply decorators to methods that need them."""
methods_to_decorate = [
method_name
for method_name in dir(self)
if callable(getattr(self, method_name)) and method_name.startswith("run_")
]
methods_to_decorate = [method_name for method_name in dir(self)
if callable(getattr(self, method_name)) and method_name.startswith('run_')]
for method_name in methods_to_decorate:
original_method = getattr(self, method_name)
decorated_method = self.log_execution_time_and_lm_rm_usage(original_method)
Expand Down
Loading

0 comments on commit 85a1e6e

Please sign in to comment.