Skip to content

Commit

Permalink
feat: add async test for proxy blacklist validation
Browse files (browse the repository at this point in the history)
  • Loading branch information
LinkW77 authored and jameszyao committed Jul 16, 2024
1 parent 4d60ad8 commit 8717780
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 1 deletion.
7 changes: 7 additions & 0 deletions inference/app/routes/chat_completion/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ModelSchema,
BaseModelProperties,
)
from config import CONFIG
from .schema import *

router = APIRouter()
Expand Down Expand Up @@ -210,6 +211,12 @@ async def api_chat_completion(
ErrorCode.REQUEST_VALIDATION_ERROR, "Model type should be chat_completion, but got " + model_type
)

# check if proxy is blacklisted
if data.proxy:
for url in CONFIG.PROVIDER_URL_BLACK_LIST:
if url in data.proxy:
raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, f"Invalid provider url: {url}")

if data.stream:

async def generator():
Expand Down
8 changes: 8 additions & 0 deletions inference/app/routes/rerank/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from app.models.tokenizer import string_tokens
from app.models.rerank import *
from .schema import *
from config import CONFIG
import logging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -50,6 +51,13 @@ async def api_rerank(
raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, "Model type should be rerank, but got " + model_type)

(model_schema, provider_model_id, properties, _) = model_infos[0]

# check if proxy is blacklisted
if data.proxy:
for url in CONFIG.PROVIDER_URL_BLACK_LIST:
if url in data.proxy:
raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, f"Invalid provider url: {url}")

try:
model = get_rerank_model(provider_id=model_schema.provider_id)
if not model:
Expand Down
8 changes: 8 additions & 0 deletions inference/app/routes/text_embedding/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
from typing import List, Optional
import numpy as np
from config import CONFIG

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -165,6 +166,13 @@ async def api_text_embedding(

default_embedding_size = model_infos[0][2].embedding_size
last_exception = None

# check if proxy is blacklisted
if data.proxy:
for url in CONFIG.PROVIDER_URL_BLACK_LIST:
if url in data.proxy:
raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, f"Invalid provider url: {url}")

for i, (model_schema, provider_model_id, properties, _) in enumerate(model_infos):
properties: TextEmbeddingModelProperties
if default_embedding_size != properties.embedding_size:
Expand Down
23 changes: 22 additions & 1 deletion inference/app/routes/verify/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from app.error import ErrorCode, raise_http_error, TKHttpException, error_messages
from aiohttp import client_exceptions
import logging
from config import CONFIG

logger = logging.getLogger(__name__)

Expand All @@ -32,7 +33,6 @@
async def api_verify_credentials(
data: VerifyModelCredentialsSchema,
):

model_infos = [
validate_model_info(
model_schema_id=data.model_schema_id,
Expand All @@ -49,6 +49,13 @@ async def api_verify_credentials(
encrypted_credentials_dict=data.encrypted_credentials,
)
model_schema, provider_model_id, properties, model_type = model_infos[0]

# check if proxy is blacklisted
if data.proxy:
for url in CONFIG.PROVIDER_URL_BLACK_LIST:
if url in data.proxy:
raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, f"Invalid provider url: {url}")

try:
if model_type == ModelType.CHAT_COMPLETION:
from ..chat_completion.route import chat_completion, chat_completion_stream
Expand All @@ -71,6 +78,8 @@ async def api_verify_credentials(
messages=[message],
credentials=provider_credentials,
configs=config,
proxy=data.proxy,
custom_headers=data.custom_headers,
)
if response.message.content is None:
raise_http_error(ErrorCode.CREDENTIALS_VALIDATION_ERROR, error_message)
Expand All @@ -81,6 +90,8 @@ async def api_verify_credentials(
messages=[message],
credentials=provider_credentials,
configs=config,
proxy=data.proxy,
custom_headers=data.custom_headers,
):
if isinstance(response, ChatCompletion) and response.message.content is not None:
valid_response_received = True
Expand Down Expand Up @@ -119,6 +130,8 @@ async def api_verify_credentials(
credentials=provider_credentials,
configs=config,
functions=[ChatCompletionFunction(**function_dict)],
proxy=data.proxy,
custom_headers=data.custom_headers,
)
if response.message.content is None and not response.message.function_calls:
raise_http_error(ErrorCode.CREDENTIALS_VALIDATION_ERROR, error_message)
Expand All @@ -130,6 +143,8 @@ async def api_verify_credentials(
credentials=provider_credentials,
configs=config,
functions=[ChatCompletionFunction(**function_dict)],
proxy=data.proxy,
custom_headers=data.custom_headers,
):
if isinstance(response, ChatCompletion) and (
response.message.content is not None or response.message.function_calls
Expand All @@ -145,6 +160,8 @@ async def api_verify_credentials(
messages=[message],
credentials=provider_credentials,
configs=config,
proxy=data.proxy,
custom_headers=data.custom_headers,
)
if response.message.content is None:
raise_http_error(ErrorCode.CREDENTIALS_VALIDATION_ERROR, error_message)
Expand All @@ -160,6 +177,8 @@ async def api_verify_credentials(
properties=properties,
configs=TextEmbeddingModelConfiguration(),
input_type=None,
proxy=data.proxy,
custom_headers=data.custom_headers,
)
actual_embedding_size = len(response.data[0].embedding)
if not actual_embedding_size == properties.embedding_size:
Expand All @@ -180,6 +199,8 @@ async def api_verify_credentials(
],
top_n=3,
credentials=provider_credentials,
proxy=data.proxy,
custom_headers=data.custom_headers,
)
else:
raise_http_error(
Expand Down
11 changes: 11 additions & 0 deletions inference/app/routes/verify/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,14 @@ class VerifyModelCredentialsSchema(BaseModel):
description="The encrypted credentials of the model provider to be verified.",
examples=[None],
)

proxy: Optional[str] = Field(None, description="The proxy of the model.")

custom_headers: Optional[Dict[str, str]] = Field(
None,
min_items=0,
max_items=16,
description="The custom headers can store up to 16 key-value pairs where each key's "
"length is less than 64 and value's length is less than 512.",
examples=[{"key1": "value1"}, {"key2": "value2"}],
)
32 changes: 32 additions & 0 deletions inference/test/test_chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,38 @@ async def test_chat_completion_by_proxy(self, test_data):
assert res_json.get("data").get("message").get("content") is not None
assert res_json.get("data").get("message").get("function_calls") is None

@pytest.mark.asyncio
@pytest.mark.test_id("inference_030")
@pytest.mark.flaky(reruns=3, reruns_delay=1)
@pytest.mark.parametrize("provider_url", Config.PROVIDER_URL_BLACK_LIST)
async def test_chat_completion_by_error_proxy(self, provider_url):
    """Request a non-streaming chat completion with a blacklisted URL as proxy.

    Runs once per entry of ``Config.PROVIDER_URL_BLACK_LIST`` and expects a
    422 response whose error code is ``REQUEST_VALIDATION_ERROR``. The case
    is skipped when the call times out after 120 seconds or when the
    response is classified as a provider service error.
    """
    payload = {
        "model_schema_id": "openai/gpt-4o",
        "messages": [{"role": "user", "content": "Hello, nice to meet you, what is your name"}],
        "stream": False,
        "configs": {"temperature": 0.5, "top_p": 0.5},
        "proxy": provider_url,
        "custom_headers": {"Helicone-Auth": f"Bearer {Config.HELICONE_API_KEY}"},
    }

    try:
        res = await asyncio.wait_for(chat_completion(payload), timeout=120)
    except asyncio.TimeoutError:
        pytest.skip("Skipping test due to timeout after 2 minutes.")

    if is_provider_service_error(res):
        pytest.skip("Skip the test case with provider service error.")

    # The failure message is only rendered when the status check fails.
    assert res.status_code == 422, f"test_validation failed: result={res.json()}"
    body = res.json()
    assert body["status"] == "error"
    assert body["error"]["code"] == "REQUEST_VALIDATION_ERROR"
    await asyncio.sleep(1)

@pytest.mark.asyncio
@pytest.mark.parametrize(
"test_data",
Expand Down
24 changes: 24 additions & 0 deletions inference/test/test_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
from test.inference_service.inference import rerank
from .utils.utils import generate_test_cases, generate_wildcard_test_cases, check_order, is_provider_service_error
from test.setting import Config


@allure.epic("inference_service")
Expand Down Expand Up @@ -118,3 +119,26 @@ async def test_less_rerank(self, test_data):
assert res.status_code == 422, res.json()
assert res_json.get("status") == "error"
assert res_json.get("error").get("code") == "REQUEST_VALIDATION_ERROR"

@pytest.mark.asyncio
@pytest.mark.test_id("inference_031")
@pytest.mark.parametrize("provider_url", Config.PROVIDER_URL_BLACK_LIST)
@pytest.mark.flaky(reruns=3, reruns_delay=1)
async def test_rerank_with_error_proxy(self, provider_url):
    """Send a rerank request that uses a blacklisted URL as proxy.

    Runs once per entry of ``Config.PROVIDER_URL_BLACK_LIST`` and expects a
    422 response whose error code is ``REQUEST_VALIDATION_ERROR``. The case
    is skipped when the call times out after 120 seconds or when the
    response is classified as a provider service error.
    """
    request_data = {
        "model_schema_id": "cohere/rerank-english-v2.0",
        "query": self.query,
        "documents": self.documents,
        "top_n": self.top_n,
        "proxy": provider_url,
    }
    try:
        res = await asyncio.wait_for(rerank(request_data), timeout=120)
    except asyncio.TimeoutError:
        pytest.skip("Skipping test due to timeout after 2 minutes.")
    if is_provider_service_error(res):
        # Plain literal: the message has no placeholders, so no f-prefix.
        pytest.skip("Skip the test case with provider service error.")
    assert res.status_code == 422, f"test_validation failed: result={res.json()}"
    # Parse the body once and reuse it for the remaining checks.
    res_json = res.json()
    assert res_json["status"] == "error"
    assert res_json["error"]["code"] == "REQUEST_VALIDATION_ERROR"
19 changes: 19 additions & 0 deletions inference/test/test_text_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,22 @@ async def test_text_embedding_with_error_provider_url(self, test_data, provider_
assert res.status_code == 422, f"test_validation failed: result={res.json()}"
assert res.json()["status"] == "error"
assert res.json()["error"]["code"] == "REQUEST_VALIDATION_ERROR"

@pytest.mark.test_id("inference_011")
@pytest.mark.asyncio
@pytest.mark.parametrize("provider_url", Config.PROVIDER_URL_BLACK_LIST)
async def test_text_embedding_with_error_proxy(self, provider_url):
    """Send a text-embedding request that uses a blacklisted URL as proxy.

    Runs once per entry of ``Config.PROVIDER_URL_BLACK_LIST`` and expects a
    422 response whose error code is ``REQUEST_VALIDATION_ERROR``. The case
    is skipped when the call times out after 120 seconds or when the
    response is classified as a provider service error.
    """
    data = {"model_schema_id": "openai/text-embedding-3-large-256"}
    data.update(self.single_text)
    # Set after the update so the proxy under test always wins.
    data["proxy"] = provider_url
    try:
        res = await asyncio.wait_for(text_embedding(data), timeout=120)
    except asyncio.TimeoutError:
        pytest.skip("Skipping test due to timeout after 2 minutes.")
    if is_provider_service_error(res):
        # Plain literal: the message has no placeholders, so no f-prefix.
        pytest.skip("Skip the test case with provider service error.")
    assert res.status_code == 422, f"test_validation failed: result={res.json()}"
    # Parse the body once and reuse it for the remaining checks.
    res_json = res.json()
    assert res_json["status"] == "error"
    assert res_json["error"]["code"] == "REQUEST_VALIDATION_ERROR"
    await asyncio.sleep(1)
30 changes: 30 additions & 0 deletions inference/test/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,33 @@ async def test_validation_with_error_provider_url(self, test_data, provider_url)
assert res.json()["status"] == "error"
assert res.json()["error"]["code"] == "REQUEST_VALIDATION_ERROR"
await asyncio.sleep(1)

@pytest.mark.parametrize("test_data", generate_test_cases_for_validation(), ids=lambda d: d["model_schema_id"])
@pytest.mark.parametrize("provider_url", Config.PROVIDER_URL_BLACK_LIST)
@pytest.mark.asyncio
@pytest.mark.test_id("inference_028")
async def test_validation_with_error_proxy(self, test_data, provider_url):
    """Verify credentials while using a blacklisted URL as proxy.

    Parametrized over every validation test case and every entry of
    ``Config.PROVIDER_URL_BLACK_LIST``; only cases whose model schema id
    contains "openai" are exercised, the rest are skipped. Expects a 422
    response whose error code is ``REQUEST_VALIDATION_ERROR``. The case is
    also skipped when the call times out after 120 seconds.
    """
    model_schema_id = test_data["model_schema_id"]
    if "openai" not in model_schema_id:
        pytest.skip("Test not applicable for this model type")

    # Stored credentials are encrypted; decrypt each value before sending.
    decrypted = {
        name: provider_credentials.aes_decrypt(secret)
        for name, secret in test_data["credentials"].items()
    }
    payload = {
        "model_schema_id": model_schema_id,
        "model_type": test_data["model_type"],
        "credentials": decrypted,
        "proxy": provider_url,
        "custom_headers": {"Helicone-Auth": f"Bearer {Config.HELICONE_API_KEY}"},
    }

    try:
        res = await asyncio.wait_for(verify_credentials(payload), timeout=120)
    except asyncio.TimeoutError:
        pytest.skip("Skipping test due to timeout after 2 minutes.")

    # The failure message is only rendered when the status check fails.
    assert res.status_code == 422, f"test_validation failed: result={res.json()}"
    body = res.json()
    assert body["status"] == "error"
    assert body["error"]["code"] == "REQUEST_VALIDATION_ERROR"
    await asyncio.sleep(1)

0 comments on commit 8717780

Please sign in to comment.