From 31bb456947c4521378949abe2cca046f1e284562 Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Tue, 23 Jan 2024 00:35:39 -0800
Subject: [PATCH 1/2] enable root path redirection

---
 search_with_lepton.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/search_with_lepton.py b/search_with_lepton.py
index b1b4d99..09977fb 100644
--- a/search_with_lepton.py
+++ b/search_with_lepton.py
@@ -8,7 +8,7 @@
 import traceback
 from typing import Annotated, List, Generator, Optional
 
-from fastapi.responses import HTMLResponse, StreamingResponse
+from fastapi.responses import HTMLResponse, StreamingResponse, RedirectResponse
 import httpx
 from loguru import logger
 
@@ -319,12 +319,10 @@ def ask_related_questions(
                         "content": query,
                     },
                 ],
-                tools=[
-                    {
-                        "type": "function",
-                        "function": tool.get_tools_spec(ask_related_questions),
-                    }
-                ],
+                tools=[{
+                    "type": "function",
+                    "function": tool.get_tools_spec(ask_related_questions),
+                }],
                 max_tokens=512,
             )
             related = response.choices[0].message.tool_calls[0].function.arguments
@@ -478,6 +476,13 @@ def dummy_generator():
     def ui(self):
         return StaticFiles(directory="ui")
 
+    @Photon.handler(method="GET", path="/")
+    def index(self) -> RedirectResponse:
+        """
+        Redirects "/" to the ui page.
+        """
+        return RedirectResponse(url="/ui/index.html")
+
 
 if __name__ == "__main__":
     rag = RAG()

From d6889a461cbd3b918bfa452fcb17d4dbf23f29e5 Mon Sep 17 00:00:00 2001
From: Yangqing Jia
Date: Thu, 25 Jan 2024 00:41:07 -0800
Subject: [PATCH 2/2] refactor

---
 search_with_lepton.py | 162 ++++++++++++++++++++----------------------
 1 file changed, 78 insertions(+), 84 deletions(-)

diff --git a/search_with_lepton.py b/search_with_lepton.py
index 09977fb..1a08d57 100644
--- a/search_with_lepton.py
+++ b/search_with_lepton.py
@@ -13,6 +13,7 @@
 from loguru import logger
 
 import leptonai
+from leptonai import Client
 from leptonai.kv import KV
 from leptonai.photon import HTTPException, Photon, StaticFiles
 from leptonai.photon.types import to_bool
@@ -23,6 +24,20 @@
 # Constant values for the RAG model.
 ################################################################################
 
+# Search engine related. You don't really need to change this.
+BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
+BING_MKT = "en-US"
+GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
+
+# Specify the number of references from the search engine you want to use.
+# 8 is usually a good number.
+REFERENCE_COUNT = 8
+
+# Specify the default timeout for the search engine. If the search engine
+# does not respond within this time, we will return an error.
+DEFAULT_SEARCH_ENGINE_TIMEOUT = 5
+
+
 # If the user did not provide a query, we will use this default query.
 _default_query = "Who said 'live long and prosper'?"
 
@@ -76,24 +91,16 @@
 """
 
 
-def search_with_bing(
-    query,
-    endpoint,
-    bing_mkt,
-    subscription_key,
-    default_search_engine_timeout,
-    reference_count,
-):
+def search_with_bing(query: str, subscription_key: str):
     """
     Search with bing and return the contexts.
""" - params = {"q": query, "mkt": bing_mkt} - logger.info(f"{endpoint} {params} {subscription_key}") + params = {"q": query, "mkt": BING_MKT} response = requests.get( - endpoint, + BING_SEARCH_V7_ENDPOINT, headers={"Ocp-Apim-Subscription-Key": subscription_key}, params=params, - timeout=default_search_engine_timeout, + timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT, ) if not response.ok: logger.error(f"{response.status_code} {response.text}") @@ -103,24 +110,14 @@ def search_with_bing( ) json_content = response.json() try: - contexts = json_content["webPages"]["value"][:reference_count] + contexts = json_content["webPages"]["value"][:REFERENCE_COUNT] except KeyError: logger.error(f"Error encountered: {json_content}") - raise HTTPException( - status_code=500, - detail="Search engine error.", - ) + return [] return contexts -def search_with_google( - query, - endpoint, - subscription_key, - cx, - default_search_engine_timeout, - reference_count, -): +def search_with_google(query: str, subscription_key: str, cx: str): """ Search with google and return the contexts. """ @@ -128,10 +125,10 @@ def search_with_google( "key": subscription_key, "cx": cx, "q": query, - "num": reference_count, + "num": REFERENCE_COUNT, } response = requests.get( - endpoint, params=params, timeout=default_search_engine_timeout + GOOGLE_SEARCH_ENDPOINT, params=params, timeout=DEFAULT_SEARCH_ENGINE_TIMEOUT ) if not response.ok: logger.error(f"{response.status_code} {response.text}") @@ -141,13 +138,10 @@ def search_with_google( ) json_content = response.json() try: - contexts = json_content["items"][:reference_count] + contexts = json_content["items"][:REFERENCE_COUNT] except KeyError: logger.error(f"Error encountered: {json_content}") - raise HTTPException( - status_code=500, - detail="Search engine error.", - ) + return [] return contexts @@ -174,24 +168,17 @@ class RAG(Photon): "resource_shape": "cpu.small", # You most likely don't need to change this. "env": { - # Choose the backend. Currently, we support BING and GOOGLE. - "BACKEND": "BING", - # If you are using BING, you should specify the bing endpoint and mkt. - # If you do not know what this is, you can leave it as is. - "BING_SEARCH_V7_ENDPOINT": "https://api.bing.microsoft.com/", - "BING_MKT": "en-US", - # If you are using GOOGLE, you should specify the google endpoint and cx. - "GOOGLE_SEARCH_ENDPOINT": ( - "https://customsearch.googleapis.com/customsearch/v1" - ), + # Choose the backend. Currently, we support BING and GOOGLE. For + # simplicity, in this demo, if you specify the backend as LEPTON, + # we will use the hosted serverless version of lepton search api + # at https://search-api.lepton.run/ to do the search and RAG, which + # runs the same code (slightly modified and might contain improvements) + # as this demo. + "BACKEND": "LEPTON", + # If you are using google, specify the search cx. "GOOGLE_SEARCH_CX": "", - # Specify the number of references you want to use. 8 is usually a good number. - "REFERENCE_COUNT": "8", # Specify the LLM model you are going to use. "LLM_MODEL": "mixtral-8x7b", - # Specify the default timeout for the search engine. If the search engine - # does not respond within this time, we will return an error. - "DEFAULT_SEARCH_ENGINE_TIMEOUT": "5", # For all the search queries and results, we will use the Lepton KV to # store them so that we can retrieve them later. Specify the name of the # KV here. @@ -241,42 +228,41 @@ def local_client(self): def init(self): """ - Initializes the environmental variables. + Initializes photon configs. 
""" # First, log in to the workspace. leptonai.api.workspace.login() - self.reference_count = int(os.environ["REFERENCE_COUNT"]) - self.default_search_engine_timeout = int() - if os.environ["BACKEND"].upper() == "BING": + self.backend = os.environ["BACKEND"].upper() + if self.backend == "LEPTON": + self.leptonsearch_client = Client( + "https://search-api.lepton.run/", + token=os.environ.get("LEPTON_WORKSPACE_TOKEN") + or WorkspaceInfoLocalRecord.get_current_workspace_token(), + stream=True, + timeout=httpx.Timeout(connect=10, read=120, write=120, pool=10), + ) + elif self.backend == "BING": + self.search_api_key = os.environ["BING_SEARCH_V7_SUBSCRIPTION_KEY"] self.search_function = lambda query: search_with_bing( query, - os.environ["BING_SEARCH_V7_ENDPOINT"] + "v7.0/search", - os.environ["BING_MKT"], - os.environ["BING_SEARCH_V7_SUBSCRIPTION_KEY"], - int(os.environ["DEFAULT_SEARCH_ENGINE_TIMEOUT"]), - int(os.environ["REFERENCE_COUNT"]), + self.search_api_key, ) - elif os.environ["BACKEND"].upper() == "GOOGLE": + elif self.backend == "GOOGLE": + self.search_api_key = os.environ["GOOGLE_SEARCH_API_KEY"] self.search_function = lambda query: search_with_google( query, - os.environ["GOOGLE_SEARCH_ENDPOINT"], - os.environ["GOOGLE_SEARCH_API_KEY"], + self.search_api_key, os.environ["GOOGLE_SEARCH_CX"], - int(os.environ["DEFAULT_SEARCH_ENGINE_TIMEOUT"]), - int(os.environ["REFERENCE_COUNT"]), ) else: - raise RuntimeError("Backend must be either BING or GOOGLE.") + raise RuntimeError("Backend must be LEPTON, BING or GOOGLE.") self.model = os.environ["LLM_MODEL"] # An executor to carry out async tasks, such as uploading to KV. self.executor = concurrent.futures.ThreadPoolExecutor( max_workers=self.handler_max_concurrency * 2 ) # Create the KV to store the search results. - logger.info( - f"Creating KV {os.environ['KV_NAME']}. If this is the first time, it may" - " take a while." - ) + logger.info("Creating KV. May take a while for the first time.") self.kv = KV( os.environ["KV_NAME"], create_if_not_exists=True, error_if_exists=False ) @@ -326,6 +312,8 @@ def ask_related_questions( max_tokens=512, ) related = response.choices[0].message.tool_calls[0].function.arguments + if isinstance(related, str): + related = json.loads(related) logger.trace(f"Related questions: {related}") return related["questions"][:5] except Exception as e: @@ -348,6 +336,12 @@ def _raw_stream_response( yield json.dumps(contexts) yield "\n\n__LLM_RESPONSE__\n\n" # Second, yield the llm response. + if not contexts: + # Prepend a warning to the user + yield ( + "(The search engine returned nothing for this query. Please take the" + " answer with a grain of salt.)\n\n" + ) for chunk in llm_response: if chunk.choices: yield chunk.choices[0].delta.content or "" @@ -380,8 +374,8 @@ def stream_and_upload_to_kv( # ignore it, because we don't want to affect the user experience. 
         _ = self.executor.submit(self.kv.put, search_uuid, "".join(all_yielded_results))
 
-    @Photon.handler(method="POST")
-    def query(
+    @Photon.handler(method="POST", path="/query")
+    def query_function(
         self,
         query: str,
         search_uuid: str,
@@ -407,14 +401,10 @@ def query(
             try:
                 result = self.kv.get(search_uuid)
 
-                def dummy_generator():
+                def str_to_generator(result: str) -> Generator[str, None, None]:
                     yield result
 
-                return StreamingResponse(
-                    content=dummy_generator(),
-                    status_code=200,
-                    media_type="text/html",
-                )
+                return StreamingResponse(str_to_generator(result))
             except KeyError:
                 logger.info(f"Key {search_uuid} not found, will generate again.")
             except Exception as e:
@@ -424,19 +414,26 @@ def dummy_generator():
         else:
             raise HTTPException(status_code=400, detail="search_uuid must be provided.")
 
+        if self.backend == "LEPTON":
+            # delegate to the lepton search api.
+            result = self.leptonsearch_client.query(
+                query=query,
+                search_uuid=search_uuid,
+                generate_related_questions=generate_related_questions,
+            )
+            return StreamingResponse(content=result, media_type="text/html")
+
         # First, do a search query.
         query = query or _default_query
         # Basic attack protection: remove "[INST]" or "[/INST]" from the query
         query = re.sub(r"\[/?INST\]", "", query)
         contexts = self.search_function(query)
 
-        concatenated_contexts = "\n\n".join(
-            [f"[[citation:{i+1}]] {c['snippet']}" for i, c in enumerate(contexts)]
+        system_prompt = _rag_query_text.format(
+            context="\n\n".join(
+                [f"[[citation:{i+1}]] {c['snippet']}" for i, c in enumerate(contexts)]
+            )
         )
-
-        system_prompt = _rag_query_text.format(context=concatenated_contexts)
-
-        logger.trace(f"System prompt:\n {system_prompt}")
         try:
             client = self.local_client()
             llm_response = client.chat.completions.create(
@@ -460,10 +457,7 @@ def dummy_generator():
             related_questions_future = None
         except Exception as e:
             logger.error(f"encountered error: {e}\n{traceback.format_exc()}")
-            return HTMLResponse(
-                content="Internal server error.",
-                status_code=503,
-            )
+            return HTMLResponse("Internal server error.", 503)
 
         return StreamingResponse(
             self.stream_and_upload_to_kv(