fixed python client to do streaming inference properly (tensorzero#402)
* fixed python client to do streaming inference properly

* added raise for status on streaming

* bumped python client version

* fixed uv lock

virajmehta authored Oct 21, 2024
1 parent bae21b1 commit 66af7c9
Showing 4 changed files with 52 additions and 22 deletions.
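
In short: the client previously issued a plain POST and buffered the whole response body before yielding "stream" chunks, so streaming inference only appeared to stream. After this commit the client opens a genuine streaming connection and yields chunks as they arrive. A minimal usage sketch of the fixed behavior (the gateway URL, constructor signature, and input shape are illustrative assumptions, not taken from this diff):

```python
from tensorzero import TensorZeroGateway  # sync client; name assumed from AsyncTensorZeroGateway below

# Hypothetical local gateway address for illustration.
client = TensorZeroGateway("http://localhost:3000")

stream = client.inference(
    function_name="basic_test",  # example function name borrowed from the test suite
    input={"messages": [{"role": "user", "content": "Hello"}]},  # assumed input shape
    stream=True,
)

# With this fix, chunks are yielded as the gateway produces them,
# instead of all at once after the full response has been buffered.
for chunk in stream:
    print(chunk)
```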
2 changes: 1 addition & 1 deletion clients/python/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "tensorzero"
-version = "2024.09.3"
+version = "2024.10.0"
 description = "The Python client for TensorZero"
 readme = "README.md"
 requires-python = ">=3.10"
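(The project versions by calendar, `YYYY.MM.patch`, so the bump from `2024.09.3` to `2024.10.0` marks the first October 2024 release.)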
46 changes: 30 additions & 16 deletions clients/python/src/tensorzero/client.py
@@ -220,15 +220,21 @@ def inference(
             tool_choice,
             parallel_tool_calls,
         )
-        response = self.client.post(url, json=data)
-        try:
-            response.raise_for_status()
-        except httpx.HTTPStatusError as e:
-            raise TensorZeroError(response) from e
-        if not stream:
-            return parse_inference_response(response.json())
-        else:
+        if stream:
+            req = self.client.build_request("POST", url, json=data)
+            response = self.client.send(req, stream=True)
+            try:
+                response.raise_for_status()
+            except httpx.HTTPStatusError as e:
+                raise TensorZeroError(response) from e
             return self._stream_sse(response)
+        else:
+            response = self.client.post(url, json=data)
+            try:
+                response.raise_for_status()
+            except httpx.HTTPStatusError as e:
+                raise TensorZeroError(response) from e
+            return parse_inference_response(response.json())

     def feedback(
         self,
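The heart of the fix is visible above: `httpx.Client.post` reads the entire response body before returning, so the old code's `_stream_sse` iterated over an already-buffered response. Building the request explicitly and sending it with `stream=True` leaves the body unread so it can be consumed incrementally. A standalone sketch of this httpx pattern (the URL is a placeholder; the error handling mirrors the diff):

```python
import httpx

client = httpx.Client()

# Eager: post() returns only after the whole body has been read.
#   response = client.post("http://localhost:3000/inference", json={...})

# Deferred: build the request, then send with stream=True so the body
# can be consumed incrementally via iter_lines()/iter_bytes().
req = client.build_request("POST", "http://localhost:3000/inference", json={"stream": True})
response = client.send(req, stream=True)
try:
    response.raise_for_status()  # status line and headers are already available
    for line in response.iter_lines():
        print(line)  # each line arrives as the server emits it
finally:
    response.close()  # streamed responses must be closed explicitly
```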
@@ -296,6 +302,7 @@ def _stream_sse(
                 yield parse_inference_chunk(data)
             except json.JSONDecodeError:
                 self.logger.error(f"Failed to parse SSE data: {data}")
+        response.close()


 class AsyncTensorZeroGateway(BaseTensorZeroGateway):
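Because the response is no longer closed by a `with` block, `_stream_sse` now closes it once iteration completes. The surrounding generator is not shown in the diff; a sketch of its plausible shape (the SSE framing details are assumptions, only the `yield`, `except`, and `close` lines appear above):

```python
def _stream_sse(self, response: httpx.Response):
    for line in response.iter_lines():
        # Assumed SSE framing: payload lines are prefixed with "data: ".
        if line.startswith("data: "):
            data = line[len("data: "):]
            try:
                yield parse_inference_chunk(data)
            except json.JSONDecodeError:
                self.logger.error(f"Failed to parse SSE data: {data}")
    response.close()  # release the connection once the stream is exhausted
```

One caveat of this design: if a caller abandons the generator before exhaustion, the `close()` line never runs; wrapping the loop in `try`/`finally` would also cover early termination.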
@@ -366,15 +373,21 @@ async def inference(
             tool_choice,
             parallel_tool_calls,
         )
-        response = await self.client.post(url, json=data)
-        try:
-            response.raise_for_status()
-        except httpx.HTTPStatusError as e:
-            raise TensorZeroError(response) from e
-        if not stream:
-            return parse_inference_response(response.json())
-        else:
+        if stream:
+            req = self.client.build_request("POST", url, json=data)
+            response = await self.client.send(req, stream=True)
+            try:
+                response.raise_for_status()
+            except httpx.HTTPStatusError as e:
+                raise TensorZeroError(response) from e
             return self._stream_sse(response)
+        else:
+            response = await self.client.post(url, json=data)
+            try:
+                response.raise_for_status()
+            except httpx.HTTPStatusError as e:
+                raise TensorZeroError(response) from e
+            return parse_inference_response(response.json())

     async def feedback(
         self,
@@ -442,3 +455,4 @@ async def _stream_sse(
                 yield parse_inference_chunk(data)
             except json.JSONDecodeError:
                 self.logger.error(f"Failed to parse SSE data: {data}")
+        await response.aclose()
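The async client mirrors the sync flow: `inference` is awaited once to obtain the stream, which is then consumed with `async for`. A usage sketch (constructor arguments and input shape are assumptions; `AsyncTensorZeroGateway` is the class named in this file):

```python
import asyncio

from tensorzero import AsyncTensorZeroGateway


async def main():
    # Hypothetical gateway address for illustration.
    client = AsyncTensorZeroGateway("http://localhost:3000")
    stream = await client.inference(  # awaiting returns the async generator
        function_name="basic_test",
        input={"messages": [{"role": "user", "content": "Hello"}]},
        stream=True,
    )
    async for chunk in stream:  # chunks arrive incrementally
        print(chunk)


asyncio.run(main())
```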
24 changes: 20 additions & 4 deletions clients/python/tests/test_client.py
@@ -17,7 +17,7 @@
 ```
 """

-from time import sleep
+from time import sleep, time
 from uuid import UUID

 import pytest
@@ -67,6 +67,7 @@ async def test_async_basic_inference(async_client):

 @pytest.mark.asyncio
 async def test_async_inference_streaming(async_client):
+    start_time = time()
     stream = await async_client.inference(
         function_name="basic_test",
         input={
@@ -75,7 +76,14 @@
         },
         stream=True,
     )
-    chunks = [chunk async for chunk in stream]
+    first_chunk_duration = None
+    chunks = []
+    async for chunk in stream:
+        chunks.append(chunk)
+        if first_chunk_duration is None:
+            first_chunk_duration = time() - start_time
+    last_chunk_duration = time() - start_time - first_chunk_duration
+    assert last_chunk_duration > first_chunk_duration + 0.1
     expected_text = [
         "Wally,",
         " the",
@@ -394,6 +402,7 @@ def test_sync_basic_inference(sync_client):


 def test_sync_inference_streaming(sync_client):
+    start_time = time()
     stream = sync_client.inference(
         function_name="basic_test",
         input={
@@ -402,8 +411,15 @@
         },
         stream=True,
     )
-    chunks = list(stream)
-    print(chunks)
+    first_chunk_duration = None
+    chunks = []
+    for chunk in stream:
+        chunks.append(chunk)
+        if first_chunk_duration is None:
+            first_chunk_duration = time() - start_time
+    last_chunk_duration = time() - start_time - first_chunk_duration
+    assert last_chunk_duration > first_chunk_duration + 0.1
+
     expected_text = [
         "Wally,",
         " the",
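These assertions pin down the property the commit actually fixes: `first_chunk_duration` is the time to the first chunk, and `last_chunk_duration` is the additional time from first chunk to last. With genuine streaming the first chunk arrives early and the rest trickle in, so the spread comfortably exceeds the time-to-first-chunk. With the old buffered behavior every chunk appeared at effectively the same instant (`last_chunk_duration ≈ 0`) and the test would fail.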
2 changes: 1 addition & 1 deletion clients/python/uv.lock

(Generated lockfile; diff not rendered.)
