Skip to content

Commit

Permalink
Reset to upstream, add support for DeepSeek
Browse files Browse the repository at this point in the history
  • Loading branch information
rmcc3 committed Jul 19, 2024
1 parent 8751f05 commit f070579
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 0 deletions.
151 changes: 151 additions & 0 deletions examples/run_storm_wiki_deepseek.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
"""
STORM Wiki pipeline powered by DeepSeek models and You.com or Bing search engine.
You need to set up the following environment variables to run this script:
- DEEPSEEK_API_KEY: DeepSeek API key
- DEEPSEEK_API_BASE: DeepSeek API base URL (default is https://api.deepseek.com)
- YDC_API_KEY: You.com API key; or, BING_SEARCH_API_KEY: Bing Search API key
Output will be structured as below
args.output_dir/
topic_name/ # topic_name will follow convention of underscore-connected topic name w/o space and slash
conversation_log.json # Log of information-seeking conversation
raw_search_results.json # Raw search results from search engine
direct_gen_outline.txt # Outline directly generated with LLM's parametric knowledge
storm_gen_outline.txt # Outline refined with collected information
url_to_info.json # Sources that are used in the final article
storm_gen_article.txt # Final article generated
storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
"""

import os
import sys
import re
from argparse import ArgumentParser

from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import DeepSeekModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.utils import load_api_key


def sanitize_topic(topic):
"""
Sanitize the topic name for use in file names.
Remove or replace characters that are not allowed in file names.
"""
# Replace spaces with underscores
topic = topic.replace(' ', '_')

# Remove any character that isn't alphanumeric, underscore, or hyphen
topic = re.sub(r'[^a-zA-Z0-9_-]', '', topic)

# Ensure the topic isn't empty after sanitization
if not topic:
topic = "unnamed_topic"

return topic


def main(args):
load_api_key(toml_file_path='secrets.toml')
lm_configs = STORMWikiLMConfigs()

# Ensure DEEPSEEK_API_KEY is set
if not os.getenv("DEEPSEEK_API_KEY"):
raise ValueError("DEEPSEEK_API_KEY environment variable is not set. Please set it in your secrets.toml file.")

deepseek_kwargs = {
'api_key': os.getenv("DEEPSEEK_API_KEY"),
'api_base': os.getenv("DEEPSEEK_API_BASE", "https://api.deepseek.com"),
'temperature': args.temperature,
'top_p': args.top_p,
}

# DeepSeek offers two main models: 'deepseek-chat' for general tasks and 'deepseek-coder' for coding tasks
# Users can choose the appropriate model based on their needs
conv_simulator_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs)
question_asker_lm = DeepSeekModel(model=args.model, max_tokens=500, **deepseek_kwargs)
outline_gen_lm = DeepSeekModel(model=args.model, max_tokens=400, **deepseek_kwargs)
article_gen_lm = DeepSeekModel(model=args.model, max_tokens=700, **deepseek_kwargs)
article_polish_lm = DeepSeekModel(model=args.model, max_tokens=4000, **deepseek_kwargs)

lm_configs.set_conv_simulator_lm(conv_simulator_lm)
lm_configs.set_question_asker_lm(question_asker_lm)
lm_configs.set_outline_gen_lm(outline_gen_lm)
lm_configs.set_article_gen_lm(article_gen_lm)
lm_configs.set_article_polish_lm(article_polish_lm)

engine_args = STORMWikiRunnerArguments(
output_dir=args.output_dir,
max_conv_turn=args.max_conv_turn,
max_perspective=args.max_perspective,
search_top_k=args.search_top_k,
max_thread_num=args.max_thread_num,
)

# STORM is a knowledge curation system which consumes information from the retrieval module.
# Currently, the information source is the Internet and we use search engine API as the retrieval module.
if args.retriever == 'bing':
rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k)
elif args.retriever == 'you':
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
else:
raise ValueError(f"Invalid retriever: {args.retriever}. Choose either 'bing' or 'you'.")

runner = STORMWikiRunner(engine_args, lm_configs, rm)

topic = input('Topic: ')
sanitized_topic = sanitize_topic(topic)
runner.run(
topic=sanitized_topic,
do_research=args.do_research,
do_generate_outline=args.do_generate_outline,
do_generate_article=args.do_generate_article,
do_polish_article=args.do_polish_article,
remove_duplicate=args.remove_duplicate,
)
runner.post_run()
runner.summary()


if __name__ == '__main__':
parser = ArgumentParser()
# global arguments
parser.add_argument('--output-dir', type=str, default='./results/deepseek',
help='Directory to store the outputs.')
parser.add_argument('--max-thread-num', type=int, default=3,
help='Maximum number of threads to use. The information seeking part and the article generation'
'part can speed up by using multiple threads. Consider reducing it if keep getting '
'"Exceed rate limit" error when calling LM API.')
parser.add_argument('--retriever', type=str, choices=['bing', 'you'], required=True,
help='The search engine API to use for retrieving information.')
parser.add_argument('--model', type=str, choices=['deepseek-chat', 'deepseek-coder'], default='deepseek-chat',
help='DeepSeek model to use. "deepseek-chat" for general tasks, "deepseek-coder" for coding tasks.')
parser.add_argument('--temperature', type=float, default=1.0,
help='Sampling temperature to use.')
parser.add_argument('--top_p', type=float, default=0.9,
help='Top-p sampling parameter.')
# stage of the pipeline
parser.add_argument('--do-research', action='store_true',
help='If True, simulate conversation to research the topic; otherwise, load the results.')
parser.add_argument('--do-generate-outline', action='store_true',
help='If True, generate an outline for the topic; otherwise, load the results.')
parser.add_argument('--do-generate-article', action='store_true',
help='If True, generate an article for the topic; otherwise, load the results.')
parser.add_argument('--do-polish-article', action='store_true',
help='If True, polish the article by adding a summarization section and (optionally) removing '
'duplicate content.')
# hyperparameters for the pre-writing stage
parser.add_argument('--max-conv-turn', type=int, default=3,
help='Maximum number of questions in conversational question asking.')
parser.add_argument('--max-perspective', type=int, default=3,
help='Maximum number of perspectives to consider in perspective-guided question asking.')
parser.add_argument('--search-top-k', type=int, default=3,
help='Top k search results to consider for each search query.')
# hyperparameters for the writing stage
parser.add_argument('--retrieve-top-k', type=int, default=3,
help='Top k collected references for each section title.')
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
88 changes: 88 additions & 0 deletions knowledge_storm/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,94 @@ def __call__(

return completions

class DeepSeekModel(dspy.OpenAI):
"""A wrapper class for DeepSeek API, compatible with dspy.OpenAI."""

def __init__(
self,
model: str = "deepseek-chat",
api_key: Optional[str] = None,
api_base: str = "https://api.deepseek.com",
**kwargs
):
super().__init__(model=model, api_key=api_key, api_base=api_base, **kwargs)
self._token_usage_lock = threading.Lock()
self.prompt_tokens = 0
self.completion_tokens = 0
self.model = model
self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
self.api_base = api_base
if not self.api_key:
raise ValueError("DeepSeek API key must be provided either as an argument or as an environment variable DEEPSEEK_API_KEY")

def log_usage(self, response):
"""Log the total tokens from the DeepSeek API response."""
usage_data = response.get('usage')
if usage_data:
with self._token_usage_lock:
self.prompt_tokens += usage_data.get('prompt_tokens', 0)
self.completion_tokens += usage_data.get('completion_tokens', 0)

def get_usage_and_reset(self):
"""Get the total tokens used and reset the token usage."""
usage = {
self.model:
{'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens}
}
self.prompt_tokens = 0
self.completion_tokens = 0
return usage

@backoff.on_exception(
backoff.expo,
ERRORS,
max_time=1000,
on_backoff=backoff_hdlr,
giveup=giveup_hdlr,
)
def _create_completion(self, prompt: str, **kwargs):
"""Create a completion using the DeepSeek API."""
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
data = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
**kwargs
}
response = requests.post(f"{self.api_base}/v1/chat/completions", headers=headers, json=data)
response.raise_for_status()
return response.json()

def __call__(
self,
prompt: str,
only_completed: bool = True,
return_sorted: bool = False,
**kwargs,
) -> list[dict[str, Any]]:
"""Call the DeepSeek API to generate completions."""
assert only_completed, "for now"
assert return_sorted is False, "for now"

response = self._create_completion(prompt, **kwargs)

# Log the token usage from the DeepSeek API response.
self.log_usage(response)

choices = response["choices"]
completions = [choice["message"]["content"] for choice in choices]

history = {
"prompt": prompt,
"response": response,
"kwargs": kwargs,
}
self.history.append(history)

return completions


class AzureOpenAIModel(dspy.AzureOpenAI):
"""A wrapper class for dspy.AzureOpenAI."""
Expand Down

0 comments on commit f070579

Please sign in to comment.