Skip to content

Commit

Permalink
parse headers for errored requests (predibase#564)
Browse files Browse the repository at this point in the history
  • Loading branch information
noyoshi authored Aug 2, 2024
1 parent ea5d74b commit 3b2cc05
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
22 changes: 12 additions & 10 deletions clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
ClassifyResponse
)
from lorax.errors import parse_error
import os

# Debug mode is enabled when the LORAX_DEBUG_MODE environment variable is set
# to any value (presence is what matters, so even an empty string counts).
# When enabled, the client passes response headers to parse_error so that the
# trace ID can be appended to error messages.
# The misspelled "LORAD_DEBUG_MODE" name that was originally read here is still
# honoured so that anyone who set it keeps the same behaviour.
LORAX_DEBUG_MODE = any(
    os.getenv(env_name) is not None
    for env_name in ("LORAX_DEBUG_MODE", "LORAD_DEBUG_MODE")
)

class Client:
"""Client to make calls to a LoRAX instance
Expand Down Expand Up @@ -272,7 +274,7 @@ def generate(
payload = {"message": e.msg}

if resp.status_code != 200:
raise parse_error(resp.status_code, payload)
raise parse_error(resp.status_code, payload, resp.headers if LORAX_DEBUG_MODE else None)

return Response(**payload[0])

Expand Down Expand Up @@ -392,7 +394,7 @@ def generate_stream(
)

if resp.status_code != 200:
raise parse_error(resp.status_code, resp.json())
raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None)

# Parse ServerSentEvents
for byte_payload in resp.iter_lines():
Expand All @@ -411,7 +413,7 @@ def generate_stream(
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status_code, json_payload)
raise parse_error(resp.status_code, json_payload, resp.headers if LORAX_DEBUG_MODE else None)
yield response


Expand All @@ -438,7 +440,7 @@ def embed(self, inputs: str) -> EmbedResponse:

payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, resp.json())
raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None)

return EmbedResponse(**payload)

Expand Down Expand Up @@ -466,7 +468,7 @@ def classify(self, inputs: str) -> ClassifyResponse:

payload = resp.json()
if resp.status_code != 200:
raise parse_error(resp.status_code, resp.json())
raise parse_error(resp.status_code, resp.json(), resp.headers if LORAX_DEBUG_MODE else None)

print(payload)
return ClassifyResponse(**payload)
Expand Down Expand Up @@ -645,7 +647,7 @@ async def generate(
payload = await resp.json()

if resp.status != 200:
raise parse_error(resp.status, payload)
raise parse_error(resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None)
return Response(**payload[0])

async def generate_stream(
Expand Down Expand Up @@ -768,7 +770,7 @@ async def generate_stream(
async with session.post(self.base_url, json=request.dict(by_alias=True)) as resp:

if resp.status != 200:
raise parse_error(resp.status, await resp.json())
raise parse_error(resp.status, await resp.json(), resp.headers if LORAX_DEBUG_MODE else None)

# Parse ServerSentEvents
async for byte_payload in resp.content:
Expand All @@ -787,7 +789,7 @@ async def generate_stream(
response = StreamResponse(**json_payload)
except ValidationError:
# If we failed to parse the payload, then it is an error payload
raise parse_error(resp.status, json_payload)
raise parse_error(resp.status, json_payload, resp.headers if LORAX_DEBUG_MODE else None)
yield response


Expand All @@ -810,7 +812,7 @@ async def embed(self, inputs: str) -> EmbedResponse:
payload = await resp.json()

if resp.status != 200:
raise parse_error(resp.status, payload)
raise parse_error(resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None)
return EmbedResponse(**payload)


Expand All @@ -833,5 +835,5 @@ async def classify(self, inputs: str) -> ClassifyResponse:
payload = await resp.json()

if resp.status != 200:
raise parse_error(resp.status, payload)
raise parse_error(resp.status, payload, resp.headers if LORAX_DEBUG_MODE else None)
return ClassifyResponse(**payload)
8 changes: 7 additions & 1 deletion clients/python/lorax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, message: str, code: int):
super().__init__(f"Error status {code}: {message}")


def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
def parse_error(status_code: int, payload: Dict[str, str], headers: Dict[str, str] = None) -> Exception:
"""
Parse error given an HTTP status code and a json payload
Expand All @@ -79,9 +79,15 @@ def parse_error(status_code: int, payload: Dict[str, str]) -> Exception:
Exception: parsed exception
"""
trace_id = ""
if headers:
trace_id = headers.get("x-b3-traceid", "")
# Try to parse a LoRAX error
message = payload.get("error", "")

if trace_id:
message += f": Trace ID: {trace_id}"

error_type = payload.get("error_type", "")
if error_type == "generation":
return GenerationError(message)
Expand Down

0 comments on commit 3b2cc05

Please sign in to comment.