
Commit: Auto-sync-2024-05-13-16-07-47
shaoyijia authored May 13, 2024
1 parent 463634e commit 1b64b8a
Showing 5 changed files with 196 additions and 14 deletions.
9 changes: 7 additions & 2 deletions README.md
@@ -6,6 +6,7 @@

**Latest News** 🔥

- [2024/05] We add Bing Search support in [rm.py](src/rm.py). Test STORM with `GPT-4o` - we now configure the article generation part in our demo with the `GPT-4o` model.
- [2024/04] We release a refactored version of the STORM codebase! We define an [interface](src/interface.py) for the STORM pipeline and reimplement STORM-wiki (check out [`src/storm_wiki`](src/storm_wiki)) to demonstrate how to instantiate the pipeline. We provide an API to support customization of different language models and retrieval/search integrations.

## Overview [(Try STORM now!)](https://storm.genie.stanford.edu/)
@@ -78,8 +79,9 @@ Currently, we provide example scripts under [`examples`](examples) to demonstrat
**To run STORM with `gpt` family models**: Make sure you have set up the OpenAI API key and run the following command.

```
python scripts/run_storm_wiki_gpt.py \
python examples/run_storm_wiki_gpt.py \
--output_dir $OUTPUT_DIR \
--retriever you \
--do-research \
--do-generate-outline \
--do-generate-article \
@@ -93,10 +95,11 @@ python scripts/run_storm_wiki_gpt.py \
**To run STORM with `mistral` family models on a local vLLM server**: have a vLLM server running with the `Mistral-7B-Instruct-v0.2` model and run the following command.
```
python scripts/run_storm_wiki_mistral.py \
python examples/run_storm_wiki_mistral.py \
--url $URL \
--port $PORT \
--output_dir $OUTPUT_DIR \
--retriever you \
--do-research \
--do-generate-outline \
--do-generate-article \
@@ -126,6 +129,8 @@ The interface for each module is defined in `src/interface.py`, while their impl
As a knowledge curation engine, STORM gathers information through the Retriever module. The interface for the Retriever module is defined in [`src/interface.py`](src/interface.py). Please consult the interface documentation if you plan to create a new instance or replace the default search engine API. By default, STORM uses the You.com search engine API (see `YouRM` in [`src/rm.py`](src/rm.py)).

:new: [2024/05] We test STORM with [Bing Search](https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/reference/endpoints). See `BingSearch` in [`src/rm.py`](src/rm.py) for the configuration; you can specify `--retriever bing` to use Bing Search in our [example scripts](examples).

:star2: **PRs for integrating more search engines/retrievers are highly appreciated!**
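
As a rough sketch of what a new retriever needs to provide (illustrative only, not part of the codebase; the class name and the in-memory corpus are made up), follow the same pattern as `YouRM` and `BingSearch` in [`src/rm.py`](src/rm.py): subclass `dspy.Retrieve`, implement `get_usage_and_reset`, and have `forward` return a list of dicts with `description`, `snippets`, `title`, and `url` keys.

```python
from typing import Callable, List, Union

import dspy


class StaticCorpusRM(dspy.Retrieve):
    """Illustrative retriever over an in-memory corpus, mirroring the YouRM/BingSearch contract."""

    def __init__(self, corpus: List[dict], k=3, is_valid_source: Callable = None):
        super().__init__(k=k)
        # Each corpus entry is expected to carry 'description', 'snippets', 'title', 'url'.
        self.corpus = corpus
        self.usage = 0
        self.is_valid_source = is_valid_source if is_valid_source else (lambda x: True)

    def get_usage_and_reset(self):
        usage = self.usage
        self.usage = 0
        return {'StaticCorpusRM': usage}

    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
        queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
        self.usage += len(queries)
        collected_results = []
        for query in queries:
            # Naive keyword match; a real retriever would call its search backend here.
            hits = [d for d in self.corpus
                    if query.lower() in d['title'].lower()
                    and self.is_valid_source(d['url'])
                    and d['url'] not in exclude_urls]
            collected_results.extend(hits[:self.k])
        return collected_results
```

With this commit, such an instance can be passed directly as the `rm` argument of `STORMWikiRunner` (see `src/storm_wiki/engine.py` below).
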
### Customization of Language Models
103 changes: 98 additions & 5 deletions src/rm.py
@@ -5,6 +5,8 @@
import dspy
import requests

from utils import WebPageHelper


class YouRM(dspy.Retrieve):
    def __init__(self, ydc_api_key=None, k=3, is_valid_source: Callable = None):
@@ -18,15 +20,18 @@ def __init__(self, ydc_api_key=None, k=3, is_valid_source: Callable = None):
        self.usage = 0

        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
        self.is_valid_source = is_valid_source
        if is_valid_source:
            self.is_valid_source = is_valid_source
        else:
            self.is_valid_source = lambda x: True

    def get_usage_and_reset(self):
        usage = self.usage
        self.usage = 0

        return {'YouRM': usage}

    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
        """Search with You.com for self.k top passages for query or queries
        Args:
@@ -53,14 +58,102 @@ def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[st

                authoritative_results = []
                for r in results['hits']:
                    if self.is_valid_source is None or self.is_valid_source(r['url']):
                    if self.is_valid_source(r['url']) and r['url'] not in exclude_urls:
                        authoritative_results.append(r)
                if 'hits' in results:
                    collected_results.extend(authoritative_results[:self.k])
            except Exception as e:
                logging.error(f'Error occurs when searching query {query}: {e}')

        if exclude_urls:
            collected_results = [r for r in collected_results if r['url'] not in exclude_urls]
        return collected_results


class BingSearch(dspy.Retrieve):
    def __init__(self, bing_search_api_key=None, k=3, is_valid_source: Callable = None,
                 min_char_count: int = 150, snippet_chunk_size: int = 1000, webpage_helper_max_threads=10,
                 mkt='en-US', language='en', **kwargs):
        """
        Params:
            min_char_count: Minimum character count for the article to be considered valid.
            snippet_chunk_size: Maximum character count for each snippet.
            webpage_helper_max_threads: Maximum number of threads to use for webpage helper.
            mkt, language, **kwargs: Bing search API parameters.
            - Reference: https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/reference/query-parameters
        """
        super().__init__(k=k)
        if not bing_search_api_key and not os.environ.get("BING_SEARCH_API_KEY"):
            raise RuntimeError(
                "You must supply bing_search_api_key or set environment variable BING_SEARCH_API_KEY")
        elif bing_search_api_key:
            self.bing_api_key = bing_search_api_key
        else:
            self.bing_api_key = os.environ["BING_SEARCH_API_KEY"]
        self.endpoint = "https://api.bing.microsoft.com/v7.0/search"
        self.params = {
            'mkt': mkt,
            "setLang": language,
            "count": k,
            **kwargs
        }
        self.webpage_helper = WebPageHelper(
            min_char_count=min_char_count,
            snippet_chunk_size=snippet_chunk_size,
            max_thread_num=webpage_helper_max_threads
        )
        self.usage = 0

        # If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
        if is_valid_source:
            self.is_valid_source = is_valid_source
        else:
            self.is_valid_source = lambda x: True

    def get_usage_and_reset(self):
        usage = self.usage
        self.usage = 0

        return {'BingSearch': usage}

    def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []):
        """Search with Bing for self.k top passages for query or queries
        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            exclude_urls (List[str]): A list of urls to exclude from the search results.
        Returns:
            a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        self.usage += len(queries)

        url_to_results = {}

        headers = {"Ocp-Apim-Subscription-Key": self.bing_api_key}

        for query in queries:
            try:
                results = requests.get(
                    self.endpoint,
                    headers=headers,
                    params={**self.params, 'q': query}
                ).json()

                for d in results['webPages']['value']:
                    if self.is_valid_source(d['url']) and d['url'] not in exclude_urls:
                        url_to_results[d['url']] = {'url': d['url'], 'title': d['name'], 'description': d['snippet']}
            except Exception as e:
                logging.error(f'Error occurs when searching query {query}: {e}')

        valid_url_to_snippets = self.webpage_helper.urls_to_snippets(list(url_to_results.keys()))
        collected_results = []
        for url in valid_url_to_snippets:
            r = url_to_results[url]
            r['snippets'] = valid_url_to_snippets[url]['snippets']
            collected_results.append(r)

        return collected_results
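
For orientation, here is a minimal usage sketch of the new class (editorial, not part of the diff; it assumes the `BING_SEARCH_API_KEY` environment variable is set and that the snippet is run from `src/` so `rm` is importable):

```python
from rm import BingSearch

# Assumes BING_SEARCH_API_KEY is set in the environment.
rm = BingSearch(k=3, min_char_count=150, snippet_chunk_size=1000)
results = rm.forward("open-domain question answering")
for r in results:
    # Each dict carries 'url', 'title', 'description', and 'snippets' (a list of strings).
    print(r['title'], r['url'], len(r['snippets']))
print(rm.get_usage_and_reset())  # e.g. {'BingSearch': 1}
```
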
5 changes: 3 additions & 2 deletions src/storm_wiki/engine.py
@@ -138,12 +138,13 @@ class STORMWikiRunner(Engine):

    def __init__(self,
                 args: STORMWikiRunnerArguments,
                 lm_configs: STORMWikiLMConfigs):
                 lm_configs: STORMWikiLMConfigs,
                 rm):
        super().__init__(lm_configs=lm_configs)
        self.args = args
        self.lm_configs = lm_configs

        self.retriever = StormRetriever(k=self.args.retrieve_top_k)
        self.retriever = StormRetriever(rm=rm, k=self.args.retrieve_top_k)
        storm_persona_generator = StormPersonaGenerator(self.lm_configs.question_asker_lm)
        self.storm_knowledge_curation_module = StormKnowledgeCurationModule(
            retriever=self.retriever,
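
The runner now receives its retrieval model from the caller instead of constructing `YouRM` internally. A hedged wiring sketch (editorial, not part of the diff; `engine_args` and `lm_configs` are placeholders for a `STORMWikiRunnerArguments` and `STORMWikiLMConfigs` configured as in the example scripts, and the `YDC_API_KEY` variable name is an assumption):

```python
import os

from rm import BingSearch, YouRM
from storm_wiki.engine import STORMWikiRunner

# engine_args and lm_configs are placeholders, assumed to be set up as in examples/run_storm_wiki_gpt.py.
rm = YouRM(ydc_api_key=os.getenv("YDC_API_KEY"), k=3)  # --retriever you
# rm = BingSearch(bing_search_api_key=os.getenv("BING_SEARCH_API_KEY"), k=3)  # --retriever bing
runner = STORMWikiRunner(engine_args, lm_configs, rm)
```
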
10 changes: 6 additions & 4 deletions src/storm_wiki/modules/retriever.py
@@ -3,6 +3,7 @@
from typing import Union, List
from urllib.parse import urlparse

import dspy
import storm_wiki.modules.storm_dataclass as storm_dataclass
from interface import Retriever, Information
from rm import YouRM
@@ -28,13 +29,14 @@ def is_valid_wikipedia_source(url):


class StormRetriever(Retriever):
    def __init__(self, ydc_api_key=None, k=3):
    def __init__(self, rm: dspy.Retrieve, k=3):
        super().__init__(search_top_k=k)
        self.you_rm = YouRM(ydc_api_key=ydc_api_key, k=self.search_top_k, is_valid_source=is_valid_wikipedia_source)
        self._rm = rm
        if hasattr(rm, 'is_valid_source'):
            rm.is_valid_source = is_valid_wikipedia_source

    def retrieve(self, query: Union[str, List[str]], exclude_urls: List[str] = []) -> List[Information]:
        self.you_rm.k = self.search_top_k
        retrieved_data_list = self.you_rm(query_or_queries=query, exclude_urls=exclude_urls)
        retrieved_data_list = self._rm(query_or_queries=query, exclude_urls=exclude_urls)
        for data in retrieved_data_list:
            for i in range(len(data['snippets'])):
                # STORM generates the article with citations. We do not consider multi-hop citations.
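
In short, `StormRetriever` can now wrap any `dspy.Retrieve` implementation and replaces the wrapped model's `is_valid_source` hook with `is_valid_wikipedia_source` from this module. A brief sketch (editorial, not part of the diff; assumes `BING_SEARCH_API_KEY` is set):

```python
from rm import BingSearch
from storm_wiki.modules.retriever import StormRetriever

# Any dspy.Retrieve subclass works here; its is_valid_source hook is overridden on wrap.
retriever = StormRetriever(rm=BingSearch(k=5), k=5)
```
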
83 changes: 82 additions & 1 deletion src/utils.py
@@ -1,10 +1,15 @@
import concurrent.futures
import json
import os
import pickle
import re
import sys
from typing import List, Dict

import httpx
import toml
from langchain_text_splitters import RecursiveCharacterTextSplitter
from trafilatura import extract


def load_api_key(toml_file_path):
@@ -21,9 +26,11 @@ def load_api_key(toml_file_path):
    for key, value in data.items():
        os.environ[key] = str(value)

def makeStringRed(message):

def makeStringRed(message):
    return f"\033[91m {message}\033[00m"


class ArticleTextProcessing:
    @staticmethod
    def limit_word_count_preserve_newline(input_string, max_word_count):
@@ -323,3 +330,77 @@ def dump_pickle(obj, path):
def load_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


class WebPageHelper:
    """Helper class to process web pages.
    Acknowledgement: Part of the code is adapted from https://github.com/stanford-oval/WikiChat project.
    """

    def __init__(self, min_char_count: int = 150, snippet_chunk_size: int = 1000, max_thread_num: int = 10):
        """
        Args:
            min_char_count: Minimum character count for the article to be considered valid.
            snippet_chunk_size: Maximum character count for each snippet.
            max_thread_num: Maximum number of threads to use for concurrent requests (e.g., downloading webpages).
        """
        self.httpx_client = httpx.Client(verify=False)
        self.min_char_count = min_char_count
        self.max_thread_num = max_thread_num
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=snippet_chunk_size,
            chunk_overlap=0,
            length_function=len,
            is_separator_regex=False,
            separators=[
                "\n\n",
                "\n",
                ".",
                "\uff0e",  # Fullwidth full stop
                "\u3002",  # Ideographic full stop
                ",",
                "\uff0c",  # Fullwidth comma
                "\u3001",  # Ideographic comma
                " ",
                "\u200B",  # Zero-width space
                "",
            ],
        )

    def download_webpage(self, url: str):
        try:
            res = self.httpx_client.get(url, timeout=4)
            if res.status_code >= 400:
                res.raise_for_status()
            return res.content
        except httpx.HTTPError as exc:
            print(f"Error while requesting {exc.request.url!r} - {exc!r}")
            return None

    def urls_to_articles(self, urls: List[str]) -> Dict:
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_num) as executor:
            htmls = list(executor.map(self.download_webpage, urls))

        articles = {}

        for h, u in zip(htmls, urls):
            if h is None:
                continue
            article_text = extract(
                h,
                include_tables=False,
                include_comments=False,
                output_format="text",
            )
            if article_text is not None and len(article_text) > self.min_char_count:
                articles[u] = {"text": article_text}

        return articles

    def urls_to_snippets(self, urls: List[str]) -> Dict:
        articles = self.urls_to_articles(urls)
        for u in articles:
            articles[u]["snippets"] = self.text_splitter.split_text(articles[u]["text"])

        return articles
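
A quick usage sketch for the new helper (editorial, not part of the commit; the URL is just an example):

```python
from utils import WebPageHelper

helper = WebPageHelper(min_char_count=150, snippet_chunk_size=1000, max_thread_num=10)
snippets_by_url = helper.urls_to_snippets(["https://en.wikipedia.org/wiki/Information_retrieval"])
for url, page in snippets_by_url.items():
    # Each entry keeps the extracted "text" plus "snippets" chunks of at most snippet_chunk_size characters.
    print(url, len(page["snippets"]))
```
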
