From 871778067e0e548539fe070deac9657a41f8bec6 Mon Sep 17 00:00:00 2001 From: Link Date: Fri, 28 Jun 2024 10:35:48 +0800 Subject: [PATCH] feat: add async test for proxy blacklist validation --- inference/app/routes/chat_completion/route.py | 7 ++++ inference/app/routes/rerank/route.py | 8 +++++ inference/app/routes/text_embedding/route.py | 8 +++++ inference/app/routes/verify/route.py | 23 ++++++++++++- inference/app/routes/verify/schema.py | 11 +++++++ inference/test/test_chat_completion.py | 32 +++++++++++++++++++ inference/test/test_rerank.py | 24 ++++++++++++++ inference/test/test_text_embedding.py | 19 +++++++++++ inference/test/test_validation.py | 30 +++++++++++++++++ 9 files changed, 161 insertions(+), 1 deletion(-) diff --git a/inference/app/routes/chat_completion/route.py b/inference/app/routes/chat_completion/route.py index 7f378df7..911bb881 100644 --- a/inference/app/routes/chat_completion/route.py +++ b/inference/app/routes/chat_completion/route.py @@ -19,6 +19,7 @@ ModelSchema, BaseModelProperties, ) +from config import CONFIG from .schema import * router = APIRouter() @@ -210,6 +211,12 @@ async def api_chat_completion( ErrorCode.REQUEST_VALIDATION_ERROR, "Model type should be chat_completion, but got " + model_type ) + # check if proxy is blacklisted + if data.proxy: + for url in CONFIG.PROVIDER_URL_BLACK_LIST: + if url in data.proxy: + raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, f"Invalid provider url: {url}") + if data.stream: async def generator(): diff --git a/inference/app/routes/rerank/route.py b/inference/app/routes/rerank/route.py index db464681..4474ef7a 100644 --- a/inference/app/routes/rerank/route.py +++ b/inference/app/routes/rerank/route.py @@ -9,6 +9,7 @@ from app.models.tokenizer import string_tokens from app.models.rerank import * from .schema import * +from config import CONFIG import logging logger = logging.getLogger(__name__) @@ -50,6 +51,13 @@ async def api_rerank( raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, "Model type should be rerank, but got " + model_type) (model_schema, provider_model_id, properties, _) = model_infos[0] + + # check if proxy is blacklisted + if data.proxy: + for url in CONFIG.PROVIDER_URL_BLACK_LIST: + if url in data.proxy: + raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, f"Invalid provider url: {url}") + try: model = get_rerank_model(provider_id=model_schema.provider_id) if not model: diff --git a/inference/app/routes/text_embedding/route.py b/inference/app/routes/text_embedding/route.py index 68ac3040..48a2ece7 100644 --- a/inference/app/routes/text_embedding/route.py +++ b/inference/app/routes/text_embedding/route.py @@ -13,6 +13,7 @@ import logging from typing import List, Optional import numpy as np +from config import CONFIG logger = logging.getLogger(__name__) @@ -165,6 +166,13 @@ async def api_text_embedding( default_embedding_size = model_infos[0][2].embedding_size last_exception = None + + # check if proxy is blacklisted + if data.proxy: + for url in CONFIG.PROVIDER_URL_BLACK_LIST: + if url in data.proxy: + raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, f"Invalid provider url: {url}") + for i, (model_schema, provider_model_id, properties, _) in enumerate(model_infos): properties: TextEmbeddingModelProperties if default_embedding_size != properties.embedding_size: diff --git a/inference/app/routes/verify/route.py b/inference/app/routes/verify/route.py index bc0ce9cf..16ce779e 100644 --- a/inference/app/routes/verify/route.py +++ b/inference/app/routes/verify/route.py @@ -17,6 +17,7 @@ from app.error import ErrorCode, raise_http_error, TKHttpException, error_messages from aiohttp import client_exceptions import logging +from config import CONFIG logger = logging.getLogger(__name__) @@ -32,7 +33,6 @@ async def api_verify_credentials( data: VerifyModelCredentialsSchema, ): - model_infos = [ validate_model_info( model_schema_id=data.model_schema_id, @@ -49,6 +49,13 @@ async def api_verify_credentials( encrypted_credentials_dict=data.encrypted_credentials, ) model_schema, provider_model_id, properties, model_type = model_infos[0] + + # check if proxy is blacklisted + if data.proxy: + for url in CONFIG.PROVIDER_URL_BLACK_LIST: + if url in data.proxy: + raise_http_error(ErrorCode.REQUEST_VALIDATION_ERROR, f"Invalid provider url: {url}") + try: if model_type == ModelType.CHAT_COMPLETION: from ..chat_completion.route import chat_completion, chat_completion_stream @@ -71,6 +78,8 @@ async def api_verify_credentials( messages=[message], credentials=provider_credentials, configs=config, + proxy=data.proxy, + custom_headers=data.custom_headers, ) if response.message.content is None: raise_http_error(ErrorCode.CREDENTIALS_VALIDATION_ERROR, error_message) @@ -81,6 +90,8 @@ async def api_verify_credentials( messages=[message], credentials=provider_credentials, configs=config, + proxy=data.proxy, + custom_headers=data.custom_headers, ): if isinstance(response, ChatCompletion) and response.message.content is not None: valid_response_received = True @@ -119,6 +130,8 @@ async def api_verify_credentials( credentials=provider_credentials, configs=config, functions=[ChatCompletionFunction(**function_dict)], + proxy=data.proxy, + custom_headers=data.custom_headers, ) if response.message.content is None and not response.message.function_calls: raise_http_error(ErrorCode.CREDENTIALS_VALIDATION_ERROR, error_message) @@ -130,6 +143,8 @@ async def api_verify_credentials( credentials=provider_credentials, configs=config, functions=[ChatCompletionFunction(**function_dict)], + proxy=data.proxy, + custom_headers=data.custom_headers, ): if isinstance(response, ChatCompletion) and ( response.message.content is not None or response.message.function_calls @@ -145,6 +160,8 @@ async def api_verify_credentials( messages=[message], credentials=provider_credentials, configs=config, + proxy=data.proxy, + custom_headers=data.custom_headers, ) if response.message.content is None: raise_http_error(ErrorCode.CREDENTIALS_VALIDATION_ERROR, error_message) @@ -160,6 +177,8 @@ async def api_verify_credentials( properties=properties, configs=TextEmbeddingModelConfiguration(), input_type=None, + proxy=data.proxy, + custom_headers=data.custom_headers, ) actual_embedding_size = len(response.data[0].embedding) if not actual_embedding_size == properties.embedding_size: @@ -180,6 +199,8 @@ async def api_verify_credentials( ], top_n=3, credentials=provider_credentials, + proxy=data.proxy, + custom_headers=data.custom_headers, ) else: raise_http_error( diff --git a/inference/app/routes/verify/schema.py b/inference/app/routes/verify/schema.py index cdc6a995..7c9acc41 100644 --- a/inference/app/routes/verify/schema.py +++ b/inference/app/routes/verify/schema.py @@ -47,3 +47,14 @@ class VerifyModelCredentialsSchema(BaseModel): description="The encrypted credentials of the model provider to be verified.", examples=[None], ) + + proxy: Optional[str] = Field(None, description="The proxy of the model.") + + custom_headers: Optional[Dict[str, str]] = Field( + None, + min_items=0, + max_items=16, + description="The custom headers can store up to 16 key-value pairs where each key's " + "length is less than 64 and value's length is less than 512.", + examples=[{"key1": "value1"}, {"key2": "value2"}], + ) diff --git a/inference/test/test_chat_completion.py b/inference/test/test_chat_completion.py index 5891daa9..8aeff247 100644 --- a/inference/test/test_chat_completion.py +++ b/inference/test/test_chat_completion.py @@ -593,6 +593,38 @@ async def test_chat_completion_by_proxy(self, test_data): assert res_json.get("data").get("message").get("content") is not None assert res_json.get("data").get("message").get("function_calls") is None + @pytest.mark.asyncio + @pytest.mark.test_id("inference_030") + @pytest.mark.flaky(reruns=3, reruns_delay=1) + @pytest.mark.parametrize("provider_url", Config.PROVIDER_URL_BLACK_LIST) + async def test_chat_completion_by_error_proxy(self, provider_url): + model_schema_id = "openai/gpt-4o" + message = [{"role": "user", "content": "Hello, nice to meet you, what is your name"}] + configs = { + "temperature": 0.5, + "top_p": 0.5, + } + proxy = provider_url + custom_headers = {"Helicone-Auth": f"Bearer {Config.HELICONE_API_KEY}"} + request_data = { + "model_schema_id": model_schema_id, + "messages": message, + "stream": False, + "configs": configs, + "proxy": proxy, + "custom_headers": custom_headers, + } + try: + res = await asyncio.wait_for(chat_completion(request_data), timeout=120) + except asyncio.TimeoutError: + pytest.skip("Skipping test due to timeout after 2 minutes.") + if is_provider_service_error(res): + pytest.skip("Skip the test case with provider service error.") + assert res.status_code == 422, f"test_validation failed: result={res.json()}" + assert res.json()["status"] == "error" + assert res.json()["error"]["code"] == "REQUEST_VALIDATION_ERROR" + await asyncio.sleep(1) + @pytest.mark.asyncio @pytest.mark.parametrize( "test_data", diff --git a/inference/test/test_rerank.py b/inference/test/test_rerank.py index 18afd73a..19cd47b6 100644 --- a/inference/test/test_rerank.py +++ b/inference/test/test_rerank.py @@ -3,6 +3,7 @@ import asyncio from test.inference_service.inference import rerank from .utils.utils import generate_test_cases, generate_wildcard_test_cases, check_order, is_provider_service_error +from test.setting import Config @allure.epic("inference_service") @@ -118,3 +119,26 @@ async def test_less_rerank(self, test_data): assert res.status_code == 422, res.json() assert res_json.get("status") == "error" assert res_json.get("error").get("code") == "REQUEST_VALIDATION_ERROR" + + @pytest.mark.asyncio + @pytest.mark.test_id("inference_031") + @pytest.mark.parametrize("provider_url", Config.PROVIDER_URL_BLACK_LIST) + @pytest.mark.flaky(reruns=3, reruns_delay=1) + async def test_rerank_with_error_proxy(self, provider_url): + model_schema_id = "cohere/rerank-english-v2.0" + request_data = { + "model_schema_id": model_schema_id, + "query": self.query, + "documents": self.documents, + "top_n": self.top_n, + "proxy": provider_url, + } + try: + res = await asyncio.wait_for(rerank(request_data), timeout=120) + except asyncio.TimeoutError: + pytest.skip("Skipping test due to timeout after 2 minutes.") + if is_provider_service_error(res): + pytest.skip(f"Skip the test case with provider service error.") + assert res.status_code == 422, f"test_validation failed: result={res.json()}" + assert res.json()["status"] == "error" + assert res.json()["error"]["code"] == "REQUEST_VALIDATION_ERROR" diff --git a/inference/test/test_text_embedding.py b/inference/test/test_text_embedding.py index e1db1205..c8656254 100644 --- a/inference/test/test_text_embedding.py +++ b/inference/test/test_text_embedding.py @@ -305,3 +305,22 @@ async def test_text_embedding_with_error_provider_url(self, test_data, provider_ assert res.status_code == 422, f"test_validation failed: result={res.json()}" assert res.json()["status"] == "error" assert res.json()["error"]["code"] == "REQUEST_VALIDATION_ERROR" + + @pytest.mark.test_id("inference_011") + @pytest.mark.asyncio + @pytest.mark.parametrize("provider_url", Config.PROVIDER_URL_BLACK_LIST) + async def test_text_embedding_with_error_proxy(self, provider_url): + model_schema_id = "openai/text-embedding-3-large-256" + data = {"model_schema_id": model_schema_id} + data.update(self.single_text) + data.update({"proxy": provider_url}) + try: + res = await asyncio.wait_for(text_embedding(data), timeout=120) + except asyncio.TimeoutError: + pytest.skip("Skipping test due to timeout after 2 minutes.") + if is_provider_service_error(res): + pytest.skip(f"Skip the test case with provider service error.") + assert res.status_code == 422, f"test_validation failed: result={res.json()}" + assert res.json()["status"] == "error" + assert res.json()["error"]["code"] == "REQUEST_VALIDATION_ERROR" + await asyncio.sleep(1) diff --git a/inference/test/test_validation.py b/inference/test/test_validation.py index c3556b68..40e80023 100644 --- a/inference/test/test_validation.py +++ b/inference/test/test_validation.py @@ -241,3 +241,33 @@ async def test_validation_with_error_provider_url(self, test_data, provider_url) assert res.json()["status"] == "error" assert res.json()["error"]["code"] == "REQUEST_VALIDATION_ERROR" await asyncio.sleep(1) + + @pytest.mark.parametrize("test_data", generate_test_cases_for_validation(), ids=lambda d: d["model_schema_id"]) + @pytest.mark.parametrize("provider_url", Config.PROVIDER_URL_BLACK_LIST) + @pytest.mark.asyncio + @pytest.mark.test_id("inference_028") + async def test_validation_with_error_proxy(self, test_data, provider_url): + model_schema_id = test_data["model_schema_id"] + if "openai" not in test_data["model_schema_id"]: + pytest.skip("Test not applicable for this model type") + model_type = test_data["model_type"] + credentials = { + key: provider_credentials.aes_decrypt(test_data["credentials"][key]) + for key in test_data["credentials"].keys() + } + custom_headers = {"Helicone-Auth": f"Bearer {Config.HELICONE_API_KEY}"} + request_data = { + "model_schema_id": model_schema_id, + "model_type": model_type, + "credentials": credentials, + "proxy": provider_url, + "custom_headers": custom_headers, + } + try: + res = await asyncio.wait_for(verify_credentials(request_data), timeout=120) + except asyncio.TimeoutError: + pytest.skip("Skipping test due to timeout after 2 minutes.") + assert res.status_code == 422, f"test_validation failed: result={res.json()}" + assert res.json()["status"] == "error" + assert res.json()["error"]["code"] == "REQUEST_VALIDATION_ERROR" + await asyncio.sleep(1)