"""Example: restrict web search results with include/exclude domain lists."""

from openai import OpenAI

client = OpenAI()

# Approximate location attached to the search tool.
user_location = {
    "type": "approximate",
    "country": "US",
    "city": "San Francisco",
}

# Example with domain filtering
web_search_tool = {
    "type": "web_search_preview",
    "user_location": user_location,
    # Include only academic and official sources
    "include_domains": ["arxiv.org", "openai.com", "nature.com", "*.edu", "*.gov"],
    # Exclude social media and forums
    "exclude_domains": ["medium.com", "reddit.com", "quora.com"],
}

response = client.responses.create(
    model="gpt-4o",
    tools=[web_search_tool],
    input="Latest AI research papers",
)

print(response.output_text)
class DomainValidator:
    """Utility class for validating domain formats.

    Two shapes are recognised:

    * plain domains such as ``example.com`` or ``sub.example.com``;
    * wildcard domains such as ``*.example.com`` or ``*.edu``.

    Each label is 1-63 characters, starts and ends with an ASCII letter or
    digit, and may contain hyphens in between. The final label (the TLD)
    must be at least two ASCII letters. Schemes, paths, query strings and
    trailing dots are rejected.
    """

    # One or more "label." groups followed by an alphabetic TLD.
    DOMAIN_PATTERN = re.compile(
        r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$"
    )

    # "*" optionally followed by ".<zero or more labels>.<TLD>".
    # The label group is "*" (zero or more) rather than "+", so TLD-only
    # wildcards such as "*.edu" match; with "+" they were wrongly rejected.
    # A bare "*" still matches this pattern, but validate_domain only
    # consults it for inputs that start with "*.".
    WILDCARD_DOMAIN_PATTERN = re.compile(
        r"^\*(?:\.(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)*[a-zA-Z]{2,})?$"
    )

    @classmethod
    def validate_domain(cls, domain: str) -> bool:
        """Validate a single domain format.

        Args:
            domain: The domain to validate (e.g., "example.com" or "*.example.com")

        Returns:
            True if the domain format is valid, False otherwise. Non-string
            and empty inputs are reported as invalid rather than raising.
        """
        if not isinstance(domain, str) or not domain:
            return False

        pattern = cls.WILDCARD_DOMAIN_PATTERN if domain.startswith("*.") else cls.DOMAIN_PATTERN

        # fullmatch, not match: with match(), "$" also succeeds just before a
        # trailing newline, so "example.com\n" would be accepted.
        return pattern.fullmatch(domain) is not None

    @classmethod
    def validate_domains(cls, domains: Optional[List[str]]) -> List[str]:
        """Validate a list of domains and return only valid ones.

        Args:
            domains: List of domains to validate. ``None`` or an empty list
                yields an empty result.

        Returns:
            List of valid domains, in their original order.
        """
        if not domains:
            return []

        return [domain for domain in domains if cls.validate_domain(domain)]
-from typing import Optional +from typing import List, Optional from typing_extensions import Literal from ..._models import BaseModel @@ -47,3 +47,15 @@ class WebSearchTool(BaseModel): user_location: Optional[UserLocation] = None """The user's location.""" + + include_domains: Optional[List[str]] = None + """List of domains to limit search results to. + + Example: ["arxiv.org", "openai.com", "nature.com"] + """ + + exclude_domains: Optional[List[str]] = None + """List of domains to exclude from search results. + + Example: ["medium.com", "reddit.com"] + """ diff --git a/src/openai/types/responses/web_search_tool_param.py b/src/openai/types/responses/web_search_tool_param.py index d0335c01a3..c0476977e7 100644 --- a/src/openai/types/responses/web_search_tool_param.py +++ b/src/openai/types/responses/web_search_tool_param.py @@ -47,3 +47,15 @@ class WebSearchToolParam(TypedDict, total=False): user_location: Optional[UserLocation] """The user's location.""" + + include_domains: Optional[List[str]] + """List of domains to limit search results to. + + Example: ["arxiv.org", "openai.com", "nature.com"] + """ + + exclude_domains: Optional[List[str]] + """List of domains to exclude from search results. 
class TestDomainValidator:
    """Tests for ``DomainValidator.validate_domain`` and ``validate_domains``."""

    def test_validate_valid_domains(self):
        """Plain and wildcard domains that must be accepted."""
        accepted = (
            "example.com",
            "sub.example.com",
            "openai.com",
            "arxiv.org",
            "nature.com",
            "*.example.com",
            "*.edu",
            "*.gov",
        )

        for domain in accepted:
            assert DomainValidator.validate_domain(domain), f"Domain {domain} should be valid"

    def test_validate_invalid_domains(self):
        """Malformed inputs (empty, URLs, stray dots, bad wildcards) that must be rejected."""
        rejected = (
            "",
            "invalid",
            "example..com",
            ".example.com",
            "example.",
            "example.com.",
            "https://example.com",
            "http://example.com",
            "example.com/path",
            "example.com?query=param",
            "*.invalid*",
            "*.",
            "*",
        )

        for domain in rejected:
            assert not DomainValidator.validate_domain(domain), f"Domain {domain} should be invalid"

    def test_validate_domains_list(self):
        """Filtering a mixed list keeps only the valid entries, in order."""
        # (candidate, whether it is expected to survive filtering)
        cases = [
            ("example.com", True),
            ("invalid", False),
            ("openai.com", True),
            ("https://bad.com", False),
            ("*.edu", True),
            ("", False),
        ]
        candidates = [domain for domain, _ in cases]
        expected_valid = [domain for domain, keep in cases if keep]

        assert DomainValidator.validate_domains(candidates) == expected_valid

    def test_validate_empty_domains_list(self):
        """Both an empty list and ``None`` produce an empty result."""
        for empty_input in ([], None):
            assert DomainValidator.validate_domains(empty_input) == []