forked from stanford-oval/storm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rm.py
66 lines (54 loc) · 2.38 KB
/
rm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import logging
import os
from typing import Callable, Union, List
import dspy
import requests
class YouRM(dspy.Retrieve):
def __init__(self, ydc_api_key=None, k=3, is_valid_source: Callable = None):
super().__init__(k=k)
if not ydc_api_key and not os.environ.get("YDC_API_KEY"):
raise RuntimeError("You must supply ydc_api_key or set environment variable YDC_API_KEY")
elif ydc_api_key:
self.ydc_api_key = ydc_api_key
else:
self.ydc_api_key = os.environ["YDC_API_KEY"]
self.usage = 0
# If not None, is_valid_source shall be a function that takes a URL and returns a boolean.
self.is_valid_source = is_valid_source
def get_usage_and_reset(self):
usage = self.usage
self.usage = 0
return {'YouRM': usage}
def forward(self, query_or_queries: Union[str, List[str]], exclude_urls: List[str]):
"""Search with You.com for self.k top passages for query or queries
Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): A list of urls to exclude from the search results.
Returns:
a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url'
"""
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
self.usage += len(queries)
collected_results = []
for query in queries:
try:
headers = {"X-API-Key": self.ydc_api_key}
results = requests.get(
f"https://api.ydc-index.io/search?query={query}",
headers=headers,
).json()
authoritative_results = []
for r in results['hits']:
if self.is_valid_source is None or self.is_valid_source(r['url']):
authoritative_results.append(r)
if 'hits' in results:
collected_results.extend(authoritative_results[:self.k])
except Exception as e:
logging.error(f'Error occurs when searching query {query}: {e}')
if exclude_urls:
collected_results = [r for r in collected_results if r['url'] not in exclude_urls]
return collected_results