Skip to content

Commit

Permalink
Update example scripts.
Browse files Browse the repository at this point in the history
  • Loading branch information
shaoyijia committed Jul 15, 2024
1 parent 97ca850 commit 708fc04
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 40 deletions.
12 changes: 5 additions & 7 deletions examples/run_storm_wiki_claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
"""

import os
import sys
from argparse import ArgumentParser

sys.path.append('./src')
from lm import ClaudeModel
from rm import YouRM, BingSearch
from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from utils import load_api_key
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import ClaudeModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.utils import load_api_key


def main(args):
Expand Down Expand Up @@ -116,4 +114,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
34 changes: 19 additions & 15 deletions examples/run_storm_wiki_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,42 @@
"""

import os
import sys
from argparse import ArgumentParser

sys.path.append('./src')
from lm import OpenAIModel
from rm import YouRM, BingSearch
from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from utils import load_api_key
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.utils import load_api_key


def main(args):
load_api_key(toml_file_path='secrets.toml')
lm_configs = STORMWikiLMConfigs()
openai_kwargs = {
'api_key': os.getenv("OPENAI_API_KEY"),
'api_provider': os.getenv('OPENAI_API_TYPE'),
'temperature': 1.0,
'top_p': 0.9,
'api_base': os.getenv('AZURE_API_BASE'),
'api_version': os.getenv('AZURE_API_VERSION'),
}

ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel
# If you are using Azure service, make sure the model name matches your own deployed model name.
# The default name here is only used for demonstration and may not match your case.
gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo'
gpt_4_model_name = 'gpt-4o'
if os.getenv('OPENAI_API_TYPE') == 'azure':
openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE')
openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION')

# STORM is a LM system so different components can be powered by different models.
# For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm
# which is used to split queries, synthesize answers in the conversation. We recommend using stronger models
# for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm
# which is responsible for generating sections with citations.
conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
question_asker_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=400, **openai_kwargs)
article_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=700, **openai_kwargs)
article_polish_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=4000, **openai_kwargs)
conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs)
question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs)
outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs)
article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs)
article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs)

lm_configs.set_conv_simulator_lm(conv_simulator_lm)
lm_configs.set_question_asker_lm(question_asker_lm)
Expand Down Expand Up @@ -122,4 +126,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())
29 changes: 18 additions & 11 deletions examples/run_storm_wiki_gpt_with_VectorRM.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
import sys
from argparse import ArgumentParser

sys.path.append('./src')
from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from rm import VectorRM
from lm import OpenAIModel
from utils import load_api_key
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.rm import VectorRM
from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel
from knowledge_storm.utils import load_api_key


def main(args):
Expand All @@ -45,21 +44,29 @@ def main(args):
engine_lm_configs = STORMWikiLMConfigs()
openai_kwargs = {
'api_key': os.getenv("OPENAI_API_KEY"),
'api_provider': os.getenv('OPENAI_API_TYPE'),
'temperature': 1.0,
'top_p': 0.9,
}

ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel
# If you are using Azure service, make sure the model name matches your own deployed model name.
# The default name here is only used for demonstration and may not match your case.
gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo'
gpt_4_model_name = 'gpt-4o'
if os.getenv('OPENAI_API_TYPE') == 'azure':
openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE')
openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION')

# STORM is a LM system so different components can be powered by different models.
# For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm
# which is used to split queries, synthesize answers in the conversation. We recommend using stronger models
# for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm
# which is responsible for generating sections with citations.
conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
question_asker_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=400, **openai_kwargs)
article_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=700, **openai_kwargs)
article_polish_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=4000, **openai_kwargs)
conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs)
question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs)
outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs)
article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs)
article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs)

engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm)
engine_lm_configs.set_question_asker_lm(question_asker_lm)
Expand Down
12 changes: 5 additions & 7 deletions examples/run_storm_wiki_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,14 @@
storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True)
"""
import os
import sys
from argparse import ArgumentParser

from dspy import Example

sys.path.append('./src')
from lm import VLLMClient
from rm import YouRM, BingSearch
from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from utils import load_api_key
from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs
from knowledge_storm.lm import VLLMClient
from knowledge_storm.rm import YouRM, BingSearch
from knowledge_storm.utils import load_api_key


def main(args):
Expand Down Expand Up @@ -174,4 +172,4 @@ def main(args):
parser.add_argument('--remove-duplicate', action='store_true',
help='If True, remove duplicate content from the article.')

main(parser.parse_args())
main(parser.parse_args())

0 comments on commit 708fc04

Please sign in to comment.