Skip to content

Commit

Permalink
fix: handle stream mode for Hugging Face
Browse files Browse the repository at this point in the history
  • Loading branch information
LinkW77 authored and DynamesC committed Aug 22, 2024
1 parent 5d18158 commit 74b0e95
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
36 changes: 33 additions & 3 deletions inference/providers/hugging_face/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _build_hugging_face_text_generation_payload(
"max_new_tokens": configs.max_tokens,
"top_k": configs.top_k,
},
"stream": stream,
}
return payload

Expand All @@ -70,8 +71,6 @@ async def prepare_request(
base_url = "https://api-inference.huggingface.co/models/PLACE_HOLDER_MODEL_ID"
api_url = base_url.replace("PLACE_HOLDER_MODEL_ID", provider_model_id)
headers = _build_hugging_face_header(credentials)
if stream:
raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, "Hugging Face does not support streaming.")
if functions:
raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, "Hugging Face does not support function calls.")
payload = _build_hugging_face_text_generation_payload(messages, stream, provider_model_id, configs)
Expand All @@ -92,4 +91,35 @@ def extract_function_calls(self, data: Dict, **kwargs) -> Optional[List[ChatComp
pass

def extract_finish_reason(self, data: Dict, **kwargs) -> Optional[ChatCompletionFinishReason]:
    """Report the finish reason for a non-streaming completion.

    Args:
        data: Provider response body. It is not inspected here.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        Always ``ChatCompletionFinishReason.stop``.
    """
    # A single return only: a preceding ``unknown`` return would make this
    # one unreachable. ``stop`` mirrors stream_extract_finish_reason.
    return ChatCompletionFinishReason.stop

# ------------------- handle stream chat completion response -------------------

def stream_check_error(self, sse_data: Dict, **kwargs):
    """Raise a provider API error when a streamed event carries one.

    Args:
        sse_data: Decoded payload of one server-sent event.
        **kwargs: Accepted for interface compatibility and ignored.

    The truthy value stored under ``"error"`` is handed to
    ``raise_provider_api_error``; events without one pass through silently.
    """
    error_detail = sse_data.get("error")
    if not error_detail:
        return
    raise_provider_api_error(error_detail)

def stream_extract_chunk_data(self, sse_data: Dict, **kwargs) -> Optional[Dict]:
    """Select the streamed events that carry generated text.

    Args:
        sse_data: Decoded payload of one server-sent event.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``sse_data`` itself when its ``"generated_text"`` entry is truthy,
        otherwise ``None`` (missing, ``None`` or empty text).
    """
    has_text = bool(sse_data.get("generated_text"))
    return sse_data if has_text else None

def stream_extract_chunk(
    self, index: int, chunk_data: Dict, text_content: str, **kwargs
) -> Tuple[int, Optional[ChatCompletionChunk]]:
    """Turn one chunk payload into a ``ChatCompletionChunk`` when it has text.

    Args:
        index: Index to assign to the chunk produced from this payload.
        chunk_data: Payload whose ``"generated_text"`` entry is read.
        text_content: Not used by this implementation.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``(index + 1, chunk)`` when ``"generated_text"`` is truthy, with the
        text stored as the chunk's ``delta``; otherwise ``(index, None)``.
    """
    delta_text = chunk_data.get("generated_text")
    if not delta_text:
        # Nothing to emit, so the index is handed back unchanged.
        return index, None

    chunk = ChatCompletionChunk(
        created_timestamp=get_current_timestamp_int(),
        index=index,
        delta=delta_text,
    )
    return index + 1, chunk

def stream_extract_finish_reason(self, chunk_data: Dict, **kwargs) -> Optional[ChatCompletionFinishReason]:
    """Report the finish reason for a streamed chunk.

    Args:
        chunk_data: Chunk payload. It is not inspected here.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        Always ``ChatCompletionFinishReason.stop``.
    """
    finish_reason = ChatCompletionFinishReason.stop
    return finish_reason

def stream_handle_function_calls(
    self, chunk_data: Dict, function_calls_content: ChatCompletionFunctionCallsContent, **kwargs
) -> Optional[ChatCompletionFunctionCallsContent]:
    """No-op hook for function calls in streamed chunks.

    Args:
        chunk_data: Chunk payload. It is not inspected here.
        function_calls_content: Function-call content passed by the caller.
            It is neither read nor modified.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        Always ``None``.
    """
    return None
3 changes: 3 additions & 0 deletions inference/providers/hugging_face/resources/provider.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ name: "i18n:hugging_face_name"
description: "i18n:hugging_face_description"
updated_timestamp: 1707152831000

return_token_usage: false
return_stream_token_usage: false

credentials_schema:
type: object
properties:
Expand Down

0 comments on commit 74b0e95

Please sign in to comment.