Commit d6889a4

refactor
Yangqing committed Jan 25, 2024
1 parent 31bb456 commit d6889a4
Showing 1 changed file with 78 additions and 84 deletions.
162 changes: 78 additions & 84 deletions search_with_lepton.py
@@ -13,6 +13,7 @@
 from loguru import logger
 
 import leptonai
+from leptonai import Client
 from leptonai.kv import KV
 from leptonai.photon import HTTPException, Photon, StaticFiles
 from leptonai.photon.types import to_bool
@@ -23,6 +24,20 @@
 # Constant values for the RAG model.
 ################################################################################
 
+# Search engine related. You don't really need to change this.
+BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
+BING_MKT = "en-US"
+GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
+
+# Specify the number of references from the search engine you want to use.
+# 8 is usually a good number.
+REFERENCE_COUNT = 8
+
+# Specify the default timeout for the search engine. If the search engine
+# does not respond within this time, we will return an error.
+DEFAULT_SEARCH_ENGINE_TIMEOUT = 5
+
+
 # If the user did not provide a query, we will use this default query.
 _default_query = "Who said 'live long and prosper'?"
 
@@ -76,24 +91,16 @@
"""


def search_with_bing(
query,
endpoint,
bing_mkt,
subscription_key,
default_search_engine_timeout,
reference_count,
):
def search_with_bing(query: str, subscription_key: str):
"""
Search with bing and return the contexts.
"""
params = {"q": query, "mkt": bing_mkt}
logger.info(f"{endpoint} {params} {subscription_key}")
params = {"q": query, "mkt": BING_MKT}
response = requests.get(
endpoint,
BING_SEARCH_V7_ENDPOINT,
headers={"Ocp-Apim-Subscription-Key": subscription_key},
params=params,
timeout=default_search_engine_timeout,
timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT,
)
if not response.ok:
logger.error(f"{response.status_code} {response.text}")
@@ -103,35 +110,25 @@ def search_with_bing(
         )
     json_content = response.json()
     try:
-        contexts = json_content["webPages"]["value"][:reference_count]
+        contexts = json_content["webPages"]["value"][:REFERENCE_COUNT]
     except KeyError:
         logger.error(f"Error encountered: {json_content}")
-        raise HTTPException(
-            status_code=500,
-            detail="Search engine error.",
-        )
+        return []
     return contexts
 
 
-def search_with_google(
-    query,
-    endpoint,
-    subscription_key,
-    cx,
-    default_search_engine_timeout,
-    reference_count,
-):
+def search_with_google(query: str, subscription_key: str, cx: str):
     """
     Search with google and return the contexts.
     """
     params = {
         "key": subscription_key,
         "cx": cx,
         "q": query,
-        "num": reference_count,
+        "num": REFERENCE_COUNT,
     }
     response = requests.get(
-        endpoint, params=params, timeout=default_search_engine_timeout
+        GOOGLE_SEARCH_ENDPOINT, params=params, timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT
     )
     if not response.ok:
         logger.error(f"{response.status_code} {response.text}")
@@ -141,13 +138,10 @@ def search_with_google(
         )
     json_content = response.json()
     try:
-        contexts = json_content["items"][:reference_count]
+        contexts = json_content["items"][:REFERENCE_COUNT]
     except KeyError:
         logger.error(f"Error encountered: {json_content}")
-        raise HTTPException(
-            status_code=500,
-            detail="Search engine error.",
-        )
+        return []
     return contexts
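As a quick illustration of the two search helpers above, a minimal sketch that assumes the Bing subscription key is exported in the environment (the query string is arbitrary; each Bing result dict carries a 'snippet' field, which is also what the RAG prompt uses):

    import os

    contexts = search_with_bing(
        "Who said 'live long and prosper'?",
        os.environ["BING_SEARCH_V7_SUBSCRIPTION_KEY"],
    )
    for i, c in enumerate(contexts):
        # An empty list here means the engine errored out or found nothing.
        print(f"[[citation:{i+1}]] {c['snippet']}")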


@@ -174,24 +168,17 @@ class RAG(Photon):
"resource_shape": "cpu.small",
# You most likely don't need to change this.
"env": {
# Choose the backend. Currently, we support BING and GOOGLE.
"BACKEND": "BING",
# If you are using BING, you should specify the bing endpoint and mkt.
# If you do not know what this is, you can leave it as is.
"BING_SEARCH_V7_ENDPOINT": "https://api.bing.microsoft.com/",
"BING_MKT": "en-US",
# If you are using GOOGLE, you should specify the google endpoint and cx.
"GOOGLE_SEARCH_ENDPOINT": (
"https://customsearch.googleapis.com/customsearch/v1"
),
# Choose the backend. Currently, we support BING and GOOGLE. For
# simplicity, in this demo, if you specify the backend as LEPTON,
# we will use the hosted serverless version of lepton search api
# at https://search-api.lepton.run/ to do the search and RAG, which
# runs the same code (slightly modified and might contain improvements)
# as this demo.
"BACKEND": "LEPTON",
# If you are using google, specify the search cx.
"GOOGLE_SEARCH_CX": "",
# Specify the number of references you want to use. 8 is usually a good number.
"REFERENCE_COUNT": "8",
# Specify the LLM model you are going to use.
"LLM_MODEL": "mixtral-8x7b",
# Specify the default timeout for the search engine. If the search engine
# does not respond within this time, we will return an error.
"DEFAULT_SEARCH_ENGINE_TIMEOUT": "5",
# For all the search queries and results, we will use the Lepton KV to
# store them so that we can retrieve them later. Specify the name of the
# KV here.
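A sketch of how these defaults can be overridden from the environment before launching the photon locally, e.g. to use your own Bing subscription instead of the hosted LEPTON backend (the key value is a placeholder):

    import os

    os.environ["BACKEND"] = "BING"  # LEPTON (default above), BING, or GOOGLE
    os.environ["BING_SEARCH_V7_SUBSCRIPTION_KEY"] = "<your-bing-key>"  # placeholder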
@@ -241,42 +228,41 @@ def local_client(self):

     def init(self):
         """
-        Initializes the environmental variables.
+        Initializes photon configs.
         """
         # First, log in to the workspace.
         leptonai.api.workspace.login()
-        self.reference_count = int(os.environ["REFERENCE_COUNT"])
-        self.default_search_engine_timeout = int()
-        if os.environ["BACKEND"].upper() == "BING":
+        self.backend = os.environ["BACKEND"].upper()
+        if self.backend == "LEPTON":
+            self.leptonsearch_client = Client(
+                "https://search-api.lepton.run/",
+                token=os.environ.get("LEPTON_WORKSPACE_TOKEN")
+                or WorkspaceInfoLocalRecord.get_current_workspace_token(),
+                stream=True,
+                timeout=httpx.Timeout(connect=10, read=120, write=120, pool=10),
+            )
+        elif self.backend == "BING":
+            self.search_api_key = os.environ["BING_SEARCH_V7_SUBSCRIPTION_KEY"]
             self.search_function = lambda query: search_with_bing(
                 query,
-                os.environ["BING_SEARCH_V7_ENDPOINT"] + "v7.0/search",
-                os.environ["BING_MKT"],
-                os.environ["BING_SEARCH_V7_SUBSCRIPTION_KEY"],
-                int(os.environ["DEFAULT_SEARCH_ENGINE_TIMEOUT"]),
-                int(os.environ["REFERENCE_COUNT"]),
+                self.search_api_key,
             )
-        elif os.environ["BACKEND"].upper() == "GOOGLE":
+        elif self.backend == "GOOGLE":
+            self.search_api_key = os.environ["GOOGLE_SEARCH_API_KEY"]
             self.search_function = lambda query: search_with_google(
                 query,
-                os.environ["GOOGLE_SEARCH_ENDPOINT"],
-                os.environ["GOOGLE_SEARCH_API_KEY"],
+                self.search_api_key,
                 os.environ["GOOGLE_SEARCH_CX"],
-                int(os.environ["DEFAULT_SEARCH_ENGINE_TIMEOUT"]),
-                int(os.environ["REFERENCE_COUNT"]),
             )
         else:
-            raise RuntimeError("Backend must be either BING or GOOGLE.")
+            raise RuntimeError("Backend must be LEPTON, BING or GOOGLE.")
         self.model = os.environ["LLM_MODEL"]
         # An executor to carry out async tasks, such as uploading to KV.
         self.executor = concurrent.futures.ThreadPoolExecutor(
             max_workers=self.handler_max_concurrency * 2
         )
         # Create the KV to store the search results.
-        logger.info(
-            f"Creating KV {os.environ['KV_NAME']}. If this is the first time, it may"
-            " take a while."
-        )
+        logger.info("Creating KV. May take a while for the first time.")
         self.kv = KV(
             os.environ["KV_NAME"], create_if_not_exists=True, error_if_exists=False
         )
@@ -326,6 +312,8 @@ def ask_related_questions(
                 max_tokens=512,
             )
             related = response.choices[0].message.tool_calls[0].function.arguments
+            if isinstance(related, str):
+                related = json.loads(related)
             logger.trace(f"Related questions: {related}")
             return related["questions"][:5]
         except Exception as e:
@@ -348,6 +336,12 @@ def _raw_stream_response(
         yield json.dumps(contexts)
         yield "\n\n__LLM_RESPONSE__\n\n"
         # Second, yield the llm response.
+        if not contexts:
+            # Prepend a warning to the user
+            yield (
+                "(The search engine returned nothing for this query. Please take the"
+                " answer with a grain of salt.)\n\n"
+            )
         for chunk in llm_response:
             if chunk.choices:
                 yield chunk.choices[0].delta.content or ""
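The stream therefore has a fixed shape: a JSON list of contexts, the literal __LLM_RESPONSE__ separator, then the answer tokens. A minimal client-side sketch for splitting a fully received stream (ignoring any trailing related-questions section the full file may append):

    import json

    def parse_search_stream(body: str):
        # Everything before the separator is the JSON-encoded context list;
        # everything after it is the (possibly warning-prefixed) answer text.
        contexts_part, answer = body.split("\n\n__LLM_RESPONSE__\n\n", 1)
        return json.loads(contexts_part), answer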
@@ -380,8 +374,8 @@ def stream_and_upload_to_kv(
         # ignore it, because we don't want to affect the user experience.
         _ = self.executor.submit(self.kv.put, search_uuid, "".join(all_yielded_results))
 
-    @Photon.handler(method="POST")
-    def query(
+    @Photon.handler(method="POST", path="/query")
+    def query_function(
         self,
         query: str,
         search_uuid: str,
@@ -407,14 +401,10 @@ def query(
             try:
                 result = self.kv.get(search_uuid)
 
-                def dummy_generator():
+                def str_to_generator(result: str) -> Generator[str, None, None]:
                     yield result
 
-                return StreamingResponse(
-                    content=dummy_generator(),
-                    status_code=200,
-                    media_type="text/html",
-                )
+                return StreamingResponse(str_to_generator(result))
             except KeyError:
                 logger.info(f"Key {search_uuid} not found, will generate again.")
             except Exception as e:
@@ -424,19 +414,26 @@ def dummy_generator():
         else:
             raise HTTPException(status_code=400, detail="search_uuid must be provided.")
 
+        if self.backend == "LEPTON":
+            # delegate to the lepton search api.
+            result = self.leptonsearch_client.query(
+                query=query,
+                search_uuid=search_uuid,
+                generate_related_questions=generate_related_questions,
+            )
+            return StreamingResponse(content=result, media_type="text/html")
+
         # First, do a search query.
         query = query or _default_query
         # Basic attack protection: remove "[INST]" or "[/INST]" from the query
         query = re.sub(r"\[/?INST\]", "", query)
         contexts = self.search_function(query)
 
-        concatenated_contexts = "\n\n".join(
-            [f"[[citation:{i+1}]] {c['snippet']}" for i, c in enumerate(contexts)]
+        system_prompt = _rag_query_text.format(
+            context="\n\n".join(
+                [f"[[citation:{i+1}]] {c['snippet']}" for i, c in enumerate(contexts)]
+            )
         )
 
-        system_prompt = _rag_query_text.format(context=concatenated_contexts)
-
         logger.trace(f"System prompt:\n {system_prompt}")
         try:
             client = self.local_client()
             llm_response = client.chat.completions.create(
@@ -460,10 +457,7 @@ def dummy_generator():
             related_questions_future = None
         except Exception as e:
             logger.error(f"encountered error: {e}\n{traceback.format_exc()}")
-            return HTMLResponse(
-                content="Internal server error.",
-                status_code=503,
-            )
+            return HTMLResponse("Internal server error.", 503)
 
         return StreamingResponse(
             self.stream_and_upload_to_kv(
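For reference, a minimal sketch of exercising the /query handler over HTTP once the photon is serving; the host and port are assumptions (Photon apps default to port 8080 locally), and reusing the same search_uuid should return the cached result from KV:

    import uuid
    import requests

    payload = {
        "query": "Who said 'live long and prosper'?",
        "search_uuid": str(uuid.uuid4()),
        "generate_related_questions": True,
    }
    # Assumed local address; adjust to wherever the photon is deployed.
    with requests.post("http://localhost:8080/query", json=payload, stream=True) as r:
        r.raise_for_status()
        for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="")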
