diff --git a/.php-cs-fixer.dist.php b/.php-cs-fixer.dist.php index c050a5711..b6d2355f0 100644 --- a/.php-cs-fixer.dist.php +++ b/.php-cs-fixer.dist.php @@ -39,6 +39,25 @@ 'header' => implode('', $fileHeaderParts), ], 'php_unit_test_case_static_method_calls' => ['call_type' => 'this'], + 'ordered_class_elements' => [ + 'order' => [ + 'use_trait', + 'case', + 'constant_public', + 'constant_protected', + 'constant_private', + 'property_public', + 'property_protected', + 'property_private', + 'construct', + 'destruct', + 'magic', + 'phpunit', + 'method_public', + 'method_protected', + 'method_private', + ], + ], ]) ->setRiskyAllowed(true) ->setFinder( diff --git a/demo/.env b/demo/.env index 52c6ea228..9dceefa20 100644 --- a/demo/.env +++ b/demo/.env @@ -20,4 +20,5 @@ APP_SECRET=ccb9dca72dce53c683eaaf775bfdb253 ###< symfony/framework-bundle ### CHROMADB_HOST=chromadb +CHROMADB_PORT=8080 OPENAI_API_KEY=sk-... diff --git a/demo/README.md b/demo/README.md index 9641bbbf7..c41269847 100644 --- a/demo/README.md +++ b/demo/README.md @@ -39,11 +39,12 @@ Checkout the repository, start the docker environment and install dependencies: ```shell git clone git@github.com:symfony/ai-demo.git cd ai-demo +composer install docker compose up -d -docker compose run composer install +symfony serve -d ``` -Now you should be able to open https://localhost/ in your browser, +Now you should be able to open https://localhost:8000/ in your browser, and the chatbot UI should be available for you to start chatting. > [!NOTE] @@ -61,7 +62,7 @@ echo "OPENAI_API_KEY='sk-...'" > .env.local Verify the success of this step by running the following command: ```shell -docker compose exec app bin/console debug:dotenv +symfony console debug:dotenv ``` You should be able to see the `OPENAI_API_KEY` in the list of environment variables. 
@@ -73,13 +74,13 @@ The [Chroma DB](https://www.trychroma.com/) is a vector store that is used to st To initialize the Chroma DB, you need to run the following command: ```shell -docker compose exec app bin/console app:blog:embed -vv +symfony console app:blog:embed -vv ``` Now you should be able to run the test command and get some results: ```shell -docker compose exec app bin/console app:blog:query +symfony console app:blog:query ``` **Don't forget to set up the project in your favorite IDE or editor.** @@ -115,7 +116,7 @@ To add the server, add the following configuration to your MCP Client's settings You can test the MCP server by running the following command to start the MCP client: ```shell -php bin/console mcp:server +symfony console mcp:server ``` Then, paste `{"method":"tools/list","jsonrpc":"2.0","id":1}` to list the tools available on the MCP server. diff --git a/demo/compose.yaml b/demo/compose.yaml index 7938910f8..c7f6e98f2 100644 --- a/demo/compose.yaml +++ b/demo/compose.yaml @@ -1,19 +1,8 @@ services: - app: - image: dunglas/frankenphp - volumes: - - ./:/app - ports: - - 443:443 - tty: true - - composer: - image: composer:latest - volumes: - - ./:/app - chromadb: image: chromadb/chroma:0.5.23 + ports: + - '8080:8000' volumes: - ./chromadb:/chroma/chroma environment: diff --git a/demo/composer.json b/demo/composer.json index 2437041ae..afc5709b6 100644 --- a/demo/composer.json +++ b/demo/composer.json @@ -71,9 +71,6 @@ "symfony/flex": true, "symfony/runtime": true }, - "platform": { - "php": "8.4.7" - }, "sort-packages": true }, "extra": { diff --git a/demo/config/packages/chromadb.yaml b/demo/config/packages/chromadb.yaml index 7a10c43b4..4891d925d 100644 --- a/demo/config/packages/chromadb.yaml +++ b/demo/config/packages/chromadb.yaml @@ -2,6 +2,7 @@ services: Codewithkyrian\ChromaDB\Factory: calls: - withHost: ['%env(CHROMADB_HOST)%'] + - withPort: ['%env(CHROMADB_PORT)%'] Codewithkyrian\ChromaDB\Client: factory: 
['@Codewithkyrian\ChromaDB\Factory', 'connect'] diff --git a/examples/.env b/examples/.env index e26ac34fa..e25d79d0a 100644 --- a/examples/.env +++ b/examples/.env @@ -52,6 +52,10 @@ TAVILY_API_KEY= # For using Brave (tool) BRAVE_API_KEY= +# For using Firecrawl (tool) +FIRECRAWL_HOST=https://api.firecrawl.dev +FIRECRAWL_API_KEY= + # For using MongoDB Atlas (store) MONGODB_URI= @@ -96,3 +100,10 @@ NEO4J_HOST=http://127.0.0.1:7474 NEO4J_DATABASE=neo4j NEO4J_USERNAME=neo4j NEO4J_PASSWORD=symfonyai + +# Typesense +TYPESENSE_HOST=http://127.0.0.1:8108 +TYPESENSE_API_KEY=changeMe + +# Cerebras +CEREBRAS_API_KEY= diff --git a/examples/cerebras/chat.php b/examples/cerebras/chat.php new file mode 100644 index 000000000..fb40b2153 --- /dev/null +++ b/examples/cerebras/chat.php @@ -0,0 +1,29 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Platform\Bridge\Cerebras\Model; +use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; + +require_once dirname(__DIR__).'/bootstrap.php'; + +$platform = PlatformFactory::create(env('CEREBRAS_API_KEY'), http_client()); + +$agent = new Agent($platform, new Model(), logger: logger()); +$messages = new MessageBag( + Message::forSystem('You are a helpful assistant.'), + Message::ofUser('How is the weather in Tokyo today?'), +); +$result = $agent->call($messages); + +echo $result->getContent().\PHP_EOL; diff --git a/examples/cerebras/stream.php b/examples/cerebras/stream.php new file mode 100644 index 000000000..6b6d694bb --- /dev/null +++ b/examples/cerebras/stream.php @@ -0,0 +1,36 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Platform\Bridge\Cerebras\Model; +use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; + +require_once dirname(__DIR__).'/bootstrap.php'; + +$platform = PlatformFactory::create(env('CEREBRAS_API_KEY'), http_client()); + +$agent = new Agent($platform, new Model(), logger: logger()); + +$messages = new MessageBag( + Message::forSystem('You are an expert in places and geography who always responds concisely.'), + Message::ofUser('What are the top three destinations in France?'), +); + +$result = $agent->call($messages, [ + 'stream' => true, +]); + +foreach ($result->getContent() as $word) { + echo $word; +} +echo \PHP_EOL; diff --git a/examples/compose.yaml b/examples/compose.yaml index bd774e214..7be0538db 100644 --- a/examples/compose.yaml +++ b/examples/compose.yaml @@ -45,3 +45,16 @@ services: ports: - '7474:7474' - '7687:7687' + + typesense: + image: typesense/typesense:29.0 + environment: + TYPESENSE_API_KEY: '${TYPESENSE_API_KEY:-changeMe}' + TYPESENSE_DATA_DIR: '/data' + volumes: + - typesense_data:/data + ports: + - '8108:8108' + +volumes: + typesense_data: diff --git a/examples/gemini/server-tools.php b/examples/gemini/server-tools.php index 132354904..2a47ebaf8 100644 --- a/examples/gemini/server-tools.php +++ b/examples/gemini/server-tools.php @@ -27,7 +27,7 @@ $toolbox = new Toolbox([new Clock()], logger: logger()); $processor = new AgentProcessor($toolbox); -$agent = new Agent($platform, $llm, logger: logger()); +$agent = new Agent($platform, $llm, [$processor], [$processor], logger()); $messages = new MessageBag( Message::ofUser( diff --git a/examples/ollama/embeddings.php b/examples/ollama/embeddings.php new file mode 100644 index 000000000..822704d8c --- /dev/null +++ b/examples/ollama/embeddings.php @@ -0,0 +1,25 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that 
was distributed with this source code. + */ + +use Symfony\AI\Platform\Bridge\Ollama\Ollama; +use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory; + +require_once dirname(__DIR__).'/bootstrap.php'; + +$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client()); + +$response = $platform->invoke(new Ollama(Ollama::NOMIC_EMBED_TEXT), <<asVectors()[0]->getDimensions().\PHP_EOL; diff --git a/examples/ollama/structured-output-math.php b/examples/ollama/structured-output-math.php new file mode 100644 index 000000000..02003e4f1 --- /dev/null +++ b/examples/ollama/structured-output-math.php @@ -0,0 +1,33 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Agent\StructuredOutput\AgentProcessor; +use Symfony\AI\Fixtures\StructuredOutput\MathReasoning; +use Symfony\AI\Platform\Bridge\Ollama\Ollama; +use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; + +require_once dirname(__DIR__).'/bootstrap.php'; + +$platform = PlatformFactory::create(env('OLLAMA_HOST_URL'), http_client()); +$model = new Ollama(); + +$processor = new AgentProcessor(); +$agent = new Agent($platform, $model, [$processor], [$processor], logger()); +$messages = new MessageBag( + Message::forSystem('You are a helpful math tutor. Guide the user through the solution step by step.'), + Message::ofUser('how can I solve 8x + 7 = -23'), +); +$result = $agent->call($messages, ['output_structure' => MathReasoning::class]); + +dump($result->getContent()); diff --git a/examples/rag/cache.php b/examples/rag/cache.php new file mode 100644 index 000000000..46d3d704c --- /dev/null +++ b/examples/rag/cache.php @@ -0,0 +1,63 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\Tool\SimilaritySearch; +use Symfony\AI\Agent\Toolbox\Toolbox; +use Symfony\AI\Fixtures\Movies; +use Symfony\AI\Platform\Bridge\OpenAi\Embeddings; +use Symfony\AI\Platform\Bridge\OpenAi\Gpt; +use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Store\CacheStore; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\TextDocument; +use Symfony\AI\Store\Document\Vectorizer; +use Symfony\AI\Store\Indexer; +use Symfony\Component\Cache\Adapter\ArrayAdapter; +use Symfony\Component\Uid\Uuid; + +require_once dirname(__DIR__).'/bootstrap.php'; + +// initialize the store +$store = new CacheStore(new ArrayAdapter()); + +// create embeddings and documents +foreach (Movies::all() as $i => $movie) { + $documents[] = new TextDocument( + id: Uuid::v4(), + content: 'Title: '.$movie['title'].\PHP_EOL.'Director: '.$movie['director'].\PHP_EOL.'Description: '.$movie['description'], + metadata: new Metadata($movie), + ); +} + +// create embeddings for documents +$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client()); +$vectorizer = new Vectorizer($platform, $embeddings = new Embeddings()); +$indexer = new Indexer($vectorizer, $store, logger()); +$indexer->index($documents); + +$model = new Gpt(Gpt::GPT_4O_MINI); + +$similaritySearch = new SimilaritySearch($platform, $embeddings, $store); +$toolbox = new Toolbox([$similaritySearch], logger: logger()); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor], logger()); + +$messages = new MessageBag( + Message::forSystem('Please answer all user questions only using SimilaritySearch function.'), + Message::ofUser('Which movie fits the theme of the mafia?') +); +$result = $agent->call($messages); + +echo $result->getContent().\PHP_EOL; diff 
--git a/examples/rag/typesense.php b/examples/rag/typesense.php new file mode 100644 index 000000000..f19dae60e --- /dev/null +++ b/examples/rag/typesense.php @@ -0,0 +1,71 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\Tool\SimilaritySearch; +use Symfony\AI\Agent\Toolbox\Toolbox; +use Symfony\AI\Fixtures\Movies; +use Symfony\AI\Platform\Bridge\OpenAi\Embeddings; +use Symfony\AI\Platform\Bridge\OpenAi\Gpt; +use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; +use Symfony\AI\Store\Bridge\Typesense\Store; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\TextDocument; +use Symfony\AI\Store\Document\Vectorizer; +use Symfony\AI\Store\Indexer; +use Symfony\Component\Uid\Uuid; + +require_once dirname(__DIR__).'/bootstrap.php'; + +// initialize the store +$store = new Store( + httpClient: http_client(), + endpointUrl: env('TYPESENSE_HOST'), + apiKey: env('TYPESENSE_API_KEY'), + collection: 'movies', +); + +// initialize the index +$store->initialize(); + +// create embeddings and documents +$documents = []; +foreach (Movies::all() as $i => $movie) { + $documents[] = new TextDocument( + id: Uuid::v4(), + content: 'Title: '.$movie['title'].\PHP_EOL.'Director: '.$movie['director'].\PHP_EOL.'Description: '.$movie['description'], + metadata: new Metadata($movie), + ); +} + +// create embeddings for documents +$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client()); +$vectorizer = new Vectorizer($platform, $embeddings = new Embeddings()); +$indexer = new Indexer($vectorizer, $store, logger()); +$indexer->index($documents); + +$model = new Gpt(Gpt::GPT_4O_MINI); + +$similaritySearch = new SimilaritySearch($platform, $embeddings, $store); +$toolbox = 
new Toolbox([$similaritySearch], logger: logger()); +$processor = new AgentProcessor($toolbox); +$agent = new Agent($platform, $model, [$processor], [$processor], logger()); + +$messages = new MessageBag( + Message::forSystem('Please answer all user questions only using SimilaritySearch function.'), + Message::ofUser('Which movie fits the theme of technology?') +); +$result = $agent->call($messages); + +echo $result->getContent().\PHP_EOL; diff --git a/examples/toolbox/firecrawl-crawl.php b/examples/toolbox/firecrawl-crawl.php new file mode 100644 index 000000000..75a220f48 --- /dev/null +++ b/examples/toolbox/firecrawl-crawl.php @@ -0,0 +1,40 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\Tool\Firecrawl; +use Symfony\AI\Agent\Toolbox\Toolbox; +use Symfony\AI\Platform\Bridge\OpenAi\Gpt; +use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; + +require_once dirname(__DIR__) . '/bootstrap.php'; + +$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client()); +$model = new Gpt(Gpt::GPT_4O_MINI); + +$firecrawl = new Firecrawl( + http_client(), + env('FIRECRAWL_API_KEY'), + env('FIRECRAWL_HOST'), +); + +$toolbox = new Toolbox([$firecrawl], logger: logger()); +$toolProcessor = new AgentProcessor($toolbox); + +$agent = new Agent($platform, $model, inputProcessors: [$toolProcessor], outputProcessors: [$toolProcessor]); + +$messages = new MessageBag(Message::ofUser('Crawl the following URL: https://symfony.com/doc/current/setup.html then resume it in less than 200 words.')); +$result = $agent->call($messages); + +echo $result->getContent() . 
\PHP_EOL; diff --git a/examples/toolbox/firecrawl-map.php b/examples/toolbox/firecrawl-map.php new file mode 100644 index 000000000..a791e708b --- /dev/null +++ b/examples/toolbox/firecrawl-map.php @@ -0,0 +1,40 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\Tool\Firecrawl; +use Symfony\AI\Agent\Toolbox\Toolbox; +use Symfony\AI\Platform\Bridge\OpenAi\Gpt; +use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; + +require_once dirname(__DIR__) . '/bootstrap.php'; + +$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client()); +$model = new Gpt(Gpt::GPT_4O_MINI); + +$firecrawl = new Firecrawl( + http_client(), + env('FIRECRAWL_API_KEY'), + env('FIRECRAWL_HOST'), +); + +$toolbox = new Toolbox([$firecrawl], logger: logger()); +$toolProcessor = new AgentProcessor($toolbox); + +$agent = new Agent($platform, $model, inputProcessors: [$toolProcessor], outputProcessors: [$toolProcessor]); + +$messages = new MessageBag(Message::ofUser('Retrieve all the links from https://symfony.com then list only the ones related to the Messenger component.')); +$result = $agent->call($messages); + +echo $result->getContent() . \PHP_EOL; diff --git a/examples/toolbox/firecrawl-scrape.php b/examples/toolbox/firecrawl-scrape.php new file mode 100644 index 000000000..1c2fa872c --- /dev/null +++ b/examples/toolbox/firecrawl-scrape.php @@ -0,0 +1,40 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +use Symfony\AI\Agent\Agent; +use Symfony\AI\Agent\Toolbox\AgentProcessor; +use Symfony\AI\Agent\Toolbox\Tool\Firecrawl; +use Symfony\AI\Agent\Toolbox\Toolbox; +use Symfony\AI\Platform\Bridge\OpenAi\Gpt; +use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory; +use Symfony\AI\Platform\Message\Message; +use Symfony\AI\Platform\Message\MessageBag; + +require_once dirname(__DIR__) . '/bootstrap.php'; + +$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client()); +$model = new Gpt(Gpt::GPT_4O_MINI); + +$firecrawl = new Firecrawl( + http_client(), + env('FIRECRAWL_API_KEY'), + env('FIRECRAWL_HOST'), +); + +$toolbox = new Toolbox([$firecrawl], logger: logger()); +$toolProcessor = new AgentProcessor($toolbox); + +$agent = new Agent($platform, $model, inputProcessors: [$toolProcessor], outputProcessors: [$toolProcessor]); + +$messages = new MessageBag(Message::ofUser('Scrape the following URL: https://symfony.com/doc/current/setup.html then resume it in less than 200 words.')); +$result = $agent->call($messages); + +echo $result->getContent() . \PHP_EOL; diff --git a/src/agent/doc/index.rst b/src/agent/doc/index.rst index 1c24b4865..377cf4575 100644 --- a/src/agent/doc/index.rst +++ b/src/agent/doc/index.rst @@ -553,9 +553,9 @@ useful when certain interactions shouldn't be influenced by the memory context:: .. _`Wikipedia Tool`: https://github.com/symfony/ai/blob/main/examples/openai/toolcall-stream.php .. _`YouTube Transcriber Tool`: https://github.com/symfony/ai/blob/main/examples/openai/toolcall.php .. _`Store Component`: https://github.com/symfony/ai-store -.. _`RAG with MongoDB`: https://github.com/symfony/ai/blob/main/examples/store/mongodb-similarity-search.php -.. _`RAG with Pinecone`: https://github.com/symfony/ai/blob/main/examples/store/pinecone-similarity-search.php +.. _`RAG with MongoDB`: https://github.com/symfony/ai/blob/main/examples/rag/mongodb.php +.. 
_`RAG with Pinecone`: https://github.com/symfony/ai/blob/main/examples/rag/pinecone.php .. _`Structured Output with PHP class`: https://github.com/symfony/ai/blob/main/examples/openai/structured-output-math.php .. _`Structured Output with array`: https://github.com/symfony/ai/blob/main/examples/openai/structured-output-clock.php .. _`Chat with static memory`: https://github.com/symfony/ai/blob/main/examples/memory/static.php -.. _`Chat with embedding search memory`: https://github.com/symfony/ai/blob/main/memory/mariadb.php +.. _`Chat with embedding search memory`: https://github.com/symfony/ai/blob/main/examples/memory/mariadb.php diff --git a/src/agent/src/Toolbox/Tool/Firecrawl.php b/src/agent/src/Toolbox/Tool/Firecrawl.php new file mode 100644 index 000000000..ec5f2570c --- /dev/null +++ b/src/agent/src/Toolbox/Tool/Firecrawl.php @@ -0,0 +1,125 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Agent\Toolbox\Tool; + +use Symfony\AI\Agent\Toolbox\Attribute\AsTool; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Guillaume Loulier + * + * @see https://www.firecrawl.dev/ + */ +#[AsTool('firecrawl_scrape', description: 'Allow to scrape website using url', method: 'scrape')] +#[AsTool('firecrawl_crawl', description: 'Allow to crawl website using url', method: 'crawl')] +#[AsTool('firecrawl_map', description: 'Allow to retrieve all urls from a website using url', method: 'map')] +final readonly class Firecrawl +{ + public function __construct( + private HttpClientInterface $httpClient, + #[\SensitiveParameter] private string $apiKey, + private string $endpoint, + ) { + } + + /** + * @return array{ + * url: string, + * markdown: string, + * html: string, + * } + */ + public function scrape(string $url): array + { + $response = $this->httpClient->request('POST', 
\sprintf('%s/v1/scrape', $this->endpoint), [ + 'auth_bearer' => $this->apiKey, + 'json' => [ + 'url' => $url, + 'formats' => ['markdown', 'html'], + ], + ]); + + $scrapingPayload = $response->toArray(); + + return [ + 'url' => $url, + 'markdown' => $scrapingPayload['data']['markdown'], + 'html' => $scrapingPayload['data']['html'], + ]; + } + + /** + * @return array|array{} + */ + public function crawl(string $url): array + { + $response = $this->httpClient->request('POST', \sprintf('%s/v1/crawl', $this->endpoint), [ + 'auth_bearer' => $this->apiKey, + 'json' => [ + 'url' => $url, + 'scrapeOptions' => [ + 'formats' => ['markdown', 'html'], + ], + ], + ]); + + $crawlingPayload = $response->toArray(); + + $scrapingStatusRequest = fn (array $crawlingPayload): ResponseInterface => $this->httpClient->request('GET', \sprintf('%s/v1/crawl/%s', $this->endpoint, $crawlingPayload['id']), [ + 'auth_bearer' => $this->apiKey, + ]); + + while ('scraping' === $scrapingStatusRequest($crawlingPayload)->toArray()['status']) { + usleep(500); + } + + $scrapingPayload = $this->httpClient->request('GET', \sprintf('%s/v1/crawl/%s', $this->endpoint, $crawlingPayload['id']), [ + 'auth_bearer' => $this->apiKey, + ]); + + $finalPayload = $scrapingPayload->toArray(); + + return array_map(static fn (array $scrapedItem) => [ + 'url' => $scrapedItem['metadata']['og:url'], + 'markdown' => $scrapedItem['markdown'], + 'html' => $scrapedItem['html'], + ], $finalPayload['data']); + } + + /** + * @return array{ + * url: string, + * links: array, + * } + */ + public function map(string $url): array + { + $response = $this->httpClient->request('POST', \sprintf('%s/v1/map', $this->endpoint), [ + 'auth_bearer' => $this->apiKey, + 'json' => [ + 'url' => $url, + ], + ]); + + $mappingPayload = $response->toArray(); + + return [ + 'url' => $url, + 'links' => $mappingPayload['links'], + ]; + } +} diff --git a/src/agent/tests/InputProcessor/ModelOverrideInputProcessorTest.php 
b/src/agent/tests/InputProcessor/ModelOverrideInputProcessorTest.php index 39a6fd74f..6696d43ad 100644 --- a/src/agent/tests/InputProcessor/ModelOverrideInputProcessorTest.php +++ b/src/agent/tests/InputProcessor/ModelOverrideInputProcessorTest.php @@ -57,8 +57,8 @@ public function testProcessInputWithoutModelOption() public function testProcessInputWithInvalidModelOption() { - self::expectException(InvalidArgumentException::class); - self::expectExceptionMessage('Option "model" must be an instance of "Symfony\AI\Platform\Model".'); + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Option "model" must be an instance of "Symfony\AI\Platform\Model".'); $gpt = new Gpt(); $model = new MessageBag(); diff --git a/src/agent/tests/StructuredOutput/AgentProcessorTest.php b/src/agent/tests/StructuredOutput/AgentProcessorTest.php index 8c1a380be..9bff7062e 100644 --- a/src/agent/tests/StructuredOutput/AgentProcessorTest.php +++ b/src/agent/tests/StructuredOutput/AgentProcessorTest.php @@ -65,7 +65,7 @@ public function testProcessInputWithoutOutputStructure() public function testProcessInputThrowsExceptionWhenLlmDoesNotSupportStructuredOutput() { - self::expectException(MissingModelSupportException::class); + $this->expectException(MissingModelSupportException::class); $processor = new AgentProcessor(new ConfigurableResponseFormatFactory()); diff --git a/src/agent/tests/Toolbox/AgentProcessorTest.php b/src/agent/tests/Toolbox/AgentProcessorTest.php index b58572cca..5f0cd6408 100644 --- a/src/agent/tests/Toolbox/AgentProcessorTest.php +++ b/src/agent/tests/Toolbox/AgentProcessorTest.php @@ -90,7 +90,7 @@ public function testProcessInputWithRegisteredToolsButToolOverride() public function testProcessInputWithUnsupportedToolCallingWillThrowException() { - self::expectException(MissingModelSupportException::class); + $this->expectException(MissingModelSupportException::class); $model = new Model('gpt-3'); $processor = new 
AgentProcessor($this->createStub(ToolboxInterface::class)); diff --git a/src/agent/tests/Toolbox/MetadataFactory/ChainFactoryTest.php b/src/agent/tests/Toolbox/MetadataFactory/ChainFactoryTest.php index bfabaad99..ee11b2a38 100644 --- a/src/agent/tests/Toolbox/MetadataFactory/ChainFactoryTest.php +++ b/src/agent/tests/Toolbox/MetadataFactory/ChainFactoryTest.php @@ -47,16 +47,16 @@ protected function setUp(): void public function testTestGetMetadataNotExistingClass() { - self::expectException(ToolException::class); - self::expectExceptionMessage('The reference "NoClass" is not a valid tool.'); + $this->expectException(ToolException::class); + $this->expectExceptionMessage('The reference "NoClass" is not a valid tool.'); iterator_to_array($this->factory->getTool('NoClass')); } public function testTestGetMetadataNotConfiguredClass() { - self::expectException(ToolConfigurationException::class); - self::expectExceptionMessage(\sprintf('Method "foo" not found in tool "%s".', ToolMisconfigured::class)); + $this->expectException(ToolConfigurationException::class); + $this->expectExceptionMessage(\sprintf('Method "foo" not found in tool "%s".', ToolMisconfigured::class)); iterator_to_array($this->factory->getTool(ToolMisconfigured::class)); } diff --git a/src/agent/tests/Toolbox/MetadataFactory/MemoryFactoryTest.php b/src/agent/tests/Toolbox/MetadataFactory/MemoryFactoryTest.php index 4a3705ae0..f126fcc62 100644 --- a/src/agent/tests/Toolbox/MetadataFactory/MemoryFactoryTest.php +++ b/src/agent/tests/Toolbox/MetadataFactory/MemoryFactoryTest.php @@ -35,8 +35,8 @@ final class MemoryFactoryTest extends TestCase { public function testGetMetadataWithoutTools() { - self::expectException(ToolException::class); - self::expectExceptionMessage('The reference "SomeClass" is not a valid tool.'); + $this->expectException(ToolException::class); + $this->expectExceptionMessage('The reference "SomeClass" is not a valid tool.'); $factory = new MemoryToolFactory(); 
iterator_to_array($factory->getTool('SomeClass')); // @phpstan-ignore-line Yes, this class does not exist diff --git a/src/agent/tests/Toolbox/MetadataFactory/ReflectionFactoryTest.php b/src/agent/tests/Toolbox/MetadataFactory/ReflectionFactoryTest.php index 0c346a346..b55c11747 100644 --- a/src/agent/tests/Toolbox/MetadataFactory/ReflectionFactoryTest.php +++ b/src/agent/tests/Toolbox/MetadataFactory/ReflectionFactoryTest.php @@ -45,16 +45,16 @@ protected function setUp(): void public function testInvalidReferenceNonExistingClass() { - self::expectException(ToolException::class); - self::expectExceptionMessage('The reference "invalid" is not a valid tool.'); + $this->expectException(ToolException::class); + $this->expectExceptionMessage('The reference "invalid" is not a valid tool.'); iterator_to_array($this->factory->getTool('invalid')); // @phpstan-ignore-line Yes, this class does not exist } public function testWithoutAttribute() { - self::expectException(ToolException::class); - self::expectExceptionMessage(\sprintf('The class "%s" is not a tool, please add %s attribute.', ToolWrong::class, AsTool::class)); + $this->expectException(ToolException::class); + $this->expectExceptionMessage(\sprintf('The class "%s" is not a tool, please add %s attribute.', ToolWrong::class, AsTool::class)); iterator_to_array($this->factory->getTool(ToolWrong::class)); } diff --git a/src/agent/tests/Toolbox/Tool/FirecrawlTest.php b/src/agent/tests/Toolbox/Tool/FirecrawlTest.php new file mode 100644 index 000000000..795cceec5 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/FirecrawlTest.php @@ -0,0 +1,76 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +namespace Symfony\AI\Agent\Tests\Toolbox\Tool; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Agent\Toolbox\Tool\Firecrawl; +use Symfony\Component\HttpClient\MockHttpClient; +use Symfony\Component\HttpClient\Response\JsonMockResponse; + +#[CoversClass(Firecrawl::class)] +final class FirecrawlTest extends TestCase +{ + public function testScrape() + { + $httpClient = new MockHttpClient([ + new JsonMockResponse(json_decode(file_get_contents(__DIR__.'/fixtures/firecrawl-scrape.json'), true)), + ]); + + $firecrawl = new Firecrawl($httpClient, 'test', 'https://127.0.0.1:3002'); + + $scrapingResult = $firecrawl->scrape('https://www.symfony.com'); + + $this->assertSame('https://www.symfony.com', $scrapingResult['url']); + $this->assertNotEmpty($scrapingResult['markdown']); + $this->assertNotEmpty($scrapingResult['html']); + $this->assertSame(1, $httpClient->getRequestsCount()); + } + + public function testCrawl() + { + $httpClient = new MockHttpClient([ + new JsonMockResponse(json_decode(file_get_contents(__DIR__.'/fixtures/firecrawl-crawl-wait.json'), true)), + new JsonMockResponse(json_decode(file_get_contents(__DIR__.'/fixtures/firecrawl-crawl-status.json'), true)), + new JsonMockResponse(json_decode(file_get_contents(__DIR__.'/fixtures/firecrawl-crawl-status-done.json'), true)), + new JsonMockResponse(json_decode(file_get_contents(__DIR__.'/fixtures/firecrawl-crawl.json'), true)), + ]); + + $firecrawl = new Firecrawl($httpClient, 'test', 'https://127.0.0.1:3002'); + + $scrapingResult = $firecrawl->crawl('https://www.symfony.com'); + + $this->assertCount(1, $scrapingResult); + $this->assertNotEmpty($scrapingResult[0]); + + $firstItem = $scrapingResult[0]; + $this->assertSame('https://www.symfony.com', $firstItem['url']); + $this->assertNotEmpty($firstItem['markdown']); + $this->assertNotEmpty($firstItem['html']); + $this->assertSame(4, $httpClient->getRequestsCount()); + } + + public function testMap() + { + 
$httpClient = new MockHttpClient([ + new JsonMockResponse(json_decode(file_get_contents(__DIR__.'/fixtures/firecrawl-map.json'), true)), + ]); + + $firecrawl = new Firecrawl($httpClient, 'test', 'https://127.0.0.1:3002'); + + $mapping = $firecrawl->map('https://www.symfony.com'); + + $this->assertSame('https://www.symfony.com', $mapping['url']); + $this->assertCount(5, $mapping['links']); + $this->assertSame(1, $httpClient->getRequestsCount()); + } +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl-status-done.json b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl-status-done.json new file mode 100644 index 000000000..e0a96f2a8 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl-status-done.json @@ -0,0 +1,3 @@ +{ + "status": "completed" +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl-status.json b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl-status.json new file mode 100644 index 000000000..a58366a2f --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl-status.json @@ -0,0 +1,3 @@ +{ + "status": "scraping" +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl-wait.json b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl-wait.json new file mode 100644 index 000000000..1f10e40a6 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl-wait.json @@ -0,0 +1,5 @@ +{ + "success": true, + "id": "123-456-789", + "url": "https://127.0.0.1:3002/v1/crawl/123-456-789" +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl.json b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl.json new file mode 100644 index 000000000..16a242d3f --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-crawl.json @@ -0,0 +1,23 @@ +{ + "status": "scraping", + "total": 36, + "completed": 10, + "creditsUsed": 10, + "expiresAt": "2024-00-00T00:00:00.000Z", + "next": "https://api.firecrawl.dev/v1/crawl/123-456-789?skip=10", + "data": [ 
+ { + "markdown": "[Firecrawl Docs home page![light logo](https://mintlify.s3-us-west-1.amazonaws.com/firecrawl/logo/light.svg)!...", + "html": "...", + "metadata": { + "title": "Build a 'Chat with website' using Groq Llama 3 | Firecrawl", + "language": "en", + "sourceURL": "https://docs.firecrawl.dev/learn/rag-llama3", + "description": "Learn how to use Firecrawl, Groq Llama 3, and Langchain to build a 'Chat with your website' bot.", + "ogLocaleAlternate": [], + "statusCode": 200, + "og:url": "https://www.symfony.com" + } + } + ] +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-map.json b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-map.json new file mode 100644 index 000000000..54aa6e58b --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-map.json @@ -0,0 +1,10 @@ +{ + "status": "success", + "links": [ + "https://firecrawl.dev", + "https://www.firecrawl.dev/pricing", + "https://www.firecrawl.dev/blog", + "https://www.firecrawl.dev/playground", + "https://www.firecrawl.dev/smart-crawl" + ] +} diff --git a/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-scrape.json b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-scrape.json new file mode 100644 index 000000000..08d081e85 --- /dev/null +++ b/src/agent/tests/Toolbox/Tool/fixtures/firecrawl-scrape.json @@ -0,0 +1,22 @@ +{ + "success": true, + "data" : { + "markdown": "Launch Week I is here! 
[See our Day 2 Release 🚀](https://www.firecrawl.dev/blog/launch-week-i-day-2-doubled-rate-limits)[💥 Get 2 months free...", + "html": "expectException(ToolNotFoundException::class); + $this->expectExceptionMessage('Tool not found for call: foo_bar_baz'); $this->toolbox->execute(new ToolCall('call_1234', 'foo_bar_baz')); } public function testExecuteWithMisconfiguredTool() { - self::expectException(ToolConfigurationException::class); - self::expectExceptionMessage('Method "foo" not found in tool "Symfony\AI\Fixtures\Tool\ToolMisconfigured".'); + $this->expectException(ToolConfigurationException::class); + $this->expectExceptionMessage('Method "foo" not found in tool "Symfony\AI\Fixtures\Tool\ToolMisconfigured".'); $toolbox = new Toolbox([new ToolMisconfigured()], new ReflectionToolFactory()); @@ -171,8 +171,8 @@ public function testExecuteWithMisconfiguredTool() public function testExecuteWithException() { - self::expectException(ToolExecutionException::class); - self::expectExceptionMessage('Execution of tool "tool_exception" failed with error: Tool error.'); + $this->expectException(ToolExecutionException::class); + $this->expectExceptionMessage('Execution of tool "tool_exception" failed with error: Tool error.'); $this->toolbox->execute(new ToolCall('call_1234', 'tool_exception')); } diff --git a/src/ai-bundle/config/options.php b/src/ai-bundle/config/options.php index 8e9102443..91ab8bbe6 100644 --- a/src/ai-bundle/config/options.php +++ b/src/ai-bundle/config/options.php @@ -65,6 +65,11 @@ ->scalarNode('host_url')->defaultValue('http://127.0.0.1:1234')->end() ->end() ->end() + ->arrayNode('ollama') + ->children() + ->scalarNode('host_url')->defaultValue('http://127.0.0.1:11434')->end() + ->end() + ->end() ->end() ->end() ->arrayNode('agent') @@ -156,6 +161,15 @@ ->end() ->end() ->end() + ->arrayNode('cache') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + 
->scalarNode('service')->cannotBeEmpty()->defaultValue('cache.app')->end() + ->end() + ->end() + ->end() ->arrayNode('chroma_db') ->normalizeKeys(false) ->useAttributeAsKey('name') @@ -166,6 +180,45 @@ ->end() ->end() ->end() + ->arrayNode('clickhouse') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('dsn')->cannotBeEmpty()->end() + ->scalarNode('http_client')->cannotBeEmpty()->end() + ->scalarNode('database')->isRequired()->cannotBeEmpty()->end() + ->scalarNode('table')->isRequired()->cannotBeEmpty()->end() + ->end() + ->validate() + ->ifTrue(static fn ($v) => !isset($v['dsn']) && !isset($v['http_client'])) + ->thenInvalid('Either "dsn" or "http_client" must be configured.') + ->end() + ->end() + ->end() + ->arrayNode('meilisearch') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('endpoint')->cannotBeEmpty()->end() + ->scalarNode('api_key')->cannotBeEmpty()->end() + ->scalarNode('index_name')->cannotBeEmpty()->end() + ->scalarNode('embedder')->end() + ->scalarNode('vector_field')->end() + ->scalarNode('dimensions')->end() + ->end() + ->end() + ->end() + ->arrayNode('memory') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('distance')->cannotBeEmpty()->end() + ->end() + ->end() + ->end() ->arrayNode('mongodb') ->normalizeKeys(false) ->useAttributeAsKey('name') @@ -180,6 +233,24 @@ ->end() ->end() ->end() + ->arrayNode('neo4j') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('endpoint')->cannotBeEmpty()->end() + ->scalarNode('username')->cannotBeEmpty()->end() + ->scalarNode('password')->cannotBeEmpty()->end() + ->scalarNode('database')->cannotBeEmpty()->end() + ->scalarNode('vector_index_name')->cannotBeEmpty()->end() + ->scalarNode('node_name')->cannotBeEmpty()->end() + ->scalarNode('vector_field')->end() + 
->scalarNode('dimensions')->end() + ->scalarNode('distance')->end() + ->booleanNode('quantization')->end() + ->end() + ->end() + ->end() ->arrayNode('pinecone') ->normalizeKeys(false) ->useAttributeAsKey('name') @@ -194,6 +265,50 @@ ->end() ->end() ->end() + ->arrayNode('qdrant') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('endpoint')->cannotBeEmpty()->end() + ->scalarNode('api_key')->cannotBeEmpty()->end() + ->scalarNode('collection_name')->cannotBeEmpty()->end() + ->scalarNode('dimensions')->end() + ->scalarNode('distance')->end() + ->end() + ->end() + ->end() + ->arrayNode('surreal_db') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('endpoint')->cannotBeEmpty()->end() + ->scalarNode('username')->cannotBeEmpty()->end() + ->scalarNode('password')->cannotBeEmpty()->end() + ->scalarNode('namespace')->cannotBeEmpty()->end() + ->scalarNode('database')->cannotBeEmpty()->end() + ->scalarNode('table')->end() + ->scalarNode('vector_field')->end() + ->scalarNode('strategy')->end() + ->scalarNode('dimensions')->end() + ->booleanNode('namespaced_user')->end() + ->end() + ->end() + ->end() + ->arrayNode('typesense') + ->normalizeKeys(false) + ->useAttributeAsKey('name') + ->arrayPrototype() + ->children() + ->scalarNode('endpoint')->cannotBeEmpty()->end() + ->scalarNode('api_key')->isRequired()->end() + ->scalarNode('collection')->isRequired()->end() + ->scalarNode('vector_field')->end() + ->scalarNode('dimensions')->end() + ->end() + ->end() + ->end() ->end() ->end() ->arrayNode('indexer') diff --git a/src/ai-bundle/config/services.php b/src/ai-bundle/config/services.php index c23962654..1506bf310 100644 --- a/src/ai-bundle/config/services.php +++ b/src/ai-bundle/config/services.php @@ -26,6 +26,7 @@ use Symfony\AI\AiBundle\Security\EventListener\IsGrantedToolAttributeListener; use Symfony\AI\Platform\Bridge\Anthropic\Contract\AnthropicContract; use 
Symfony\AI\Platform\Bridge\Gemini\Contract\GeminiContract; +use Symfony\AI\Platform\Bridge\Ollama\Contract\OllamaContract; use Symfony\AI\Platform\Bridge\OpenAi\Whisper\AudioNormalizer; use Symfony\AI\Platform\Contract; use Symfony\AI\Platform\Contract\JsonSchema\DescriptionParser; @@ -44,6 +45,8 @@ ->factory([AnthropicContract::class, 'create']) ->set('ai.platform.contract.google', Contract::class) ->factory([GeminiContract::class, 'create']) + ->set('ai.platform.contract.ollama', Contract::class) + ->factory([OllamaContract::class, 'create']) // structured output ->set('ai.agent.response_format_factory', ResponseFormatFactory::class) ->args([ @@ -123,7 +126,7 @@ ]) ->tag('data_collector') ->set('ai.traceable_toolbox', TraceableToolbox::class) - ->decorate('ai.toolbox') + ->decorate('ai.toolbox', priority: -1) ->args([ service('.inner'), ]) diff --git a/src/ai-bundle/doc/index.rst b/src/ai-bundle/doc/index.rst index 846ddfdbb..fe409dffe 100644 --- a/src/ai-bundle/doc/index.rst +++ b/src/ai-bundle/doc/index.rst @@ -52,6 +52,8 @@ Configuration api_version: '%env(AZURE_GPT_VERSION)%' gemini: api_key: '%env(GEMINI_API_KEY)%' + ollama: + host_url: '%env(OLLAMA_HOST_URL)%' agent: rag: platform: 'ai.platform.azure.gpt_deployment' @@ -84,7 +86,7 @@ Configuration - 'Symfony\AI\Agent\Toolbox\Tool\Wikipedia' fault_tolerant_toolbox: false # Disables fault tolerant toolbox, default is true store: - # also azure_search, mongodb and pinecone are supported as store type + # also azure_search, meilisearch, memory, mongodb, pinecone, qdrant and surrealdb are supported as store type chroma_db: # multiple collections possible per type default: @@ -146,6 +148,9 @@ To use existing tools, you can register them as a service: $apiKey: '%env(TAVILY_API_KEY)%' Symfony\AI\Agent\Toolbox\Tool\Wikipedia: ~ Symfony\AI\Agent\Toolbox\Tool\YouTubeTranscriber: ~ + Symfony\AI\Agent\Toolbox\Tool\Firecrawl: + $endpoint: '%env(FIRECRAWL_ENDPOINT)%' + $apiKey: '%env(FIRECRAWL_API_KEY)%' Custom tools can 
be registered by using the ``#[AsTool]`` attribute:: diff --git a/src/ai-bundle/src/AiBundle.php b/src/ai-bundle/src/AiBundle.php index 046b6be2d..04c391f7c 100644 --- a/src/ai-bundle/src/AiBundle.php +++ b/src/ai-bundle/src/AiBundle.php @@ -30,8 +30,10 @@ use Symfony\AI\Platform\Bridge\Gemini\PlatformFactory as GeminiPlatformFactory; use Symfony\AI\Platform\Bridge\LmStudio\PlatformFactory as LmStudioPlatformFactory; use Symfony\AI\Platform\Bridge\Mistral\PlatformFactory as MistralPlatformFactory; +use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory as OllamaPlatformFactory; use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory as OpenAiPlatformFactory; use Symfony\AI\Platform\Bridge\OpenRouter\PlatformFactory as OpenRouterPlatformFactory; +use Symfony\AI\Platform\Bridge\Cerebras\PlatformFactory as CerebrasPlatformFactory; use Symfony\AI\Platform\Model; use Symfony\AI\Platform\ModelClientInterface; use Symfony\AI\Platform\Platform; @@ -39,10 +41,18 @@ use Symfony\AI\Platform\ResultConverterInterface; use Symfony\AI\Store\Bridge\Azure\SearchStore as AzureSearchStore; use Symfony\AI\Store\Bridge\ChromaDb\Store as ChromaDbStore; +use Symfony\AI\Store\Bridge\ClickHouse\Store as ClickHouseStore; +use Symfony\AI\Store\Bridge\Meilisearch\Store as MeilisearchStore; use Symfony\AI\Store\Bridge\MongoDb\Store as MongoDbStore; +use Symfony\AI\Store\Bridge\Neo4j\Store as Neo4jStore; use Symfony\AI\Store\Bridge\Pinecone\Store as PineconeStore; +use Symfony\AI\Store\Bridge\Qdrant\Store as QdrantStore; +use Symfony\AI\Store\Bridge\SurrealDb\Store as SurrealDbStore; +use Symfony\AI\Store\Bridge\Typesense\Store as TypesenseStore; +use Symfony\AI\Store\CacheStore; use Symfony\AI\Store\Document\Vectorizer; use Symfony\AI\Store\Indexer; +use Symfony\AI\Store\InMemoryStore; use Symfony\AI\Store\StoreInterface; use Symfony\AI\Store\VectorStoreInterface; use Symfony\Component\Config\Definition\Configurator\DefinitionConfigurator; @@ -52,8 +62,10 @@ use 
Symfony\Component\DependencyInjection\Definition; use Symfony\Component\DependencyInjection\Loader\Configurator\ContainerConfigurator; use Symfony\Component\DependencyInjection\Reference; +use Symfony\Component\HttpClient\HttpClient; use Symfony\Component\HttpKernel\Bundle\AbstractBundle; use Symfony\Component\Security\Core\Authorization\AuthorizationCheckerInterface; +use Symfony\Contracts\HttpClient\HttpClientInterface; use function Symfony\Component\String\u; @@ -271,7 +283,7 @@ private function processPlatformConfig(string $type, array $platform, ContainerB if ('lmstudio' === $type) { $platformId = 'symfony_ai.platform.lmstudio'; $definition = (new Definition(Platform::class)) - ->setFactory(LmStudioPlatformFactory::class.'::create') + ->setFactory(LmStudioPlatformFactory::class.'::create') ->setLazy(true) ->addTag('proxy', ['interface' => PlatformInterface::class]) ->setArguments([ @@ -286,6 +298,42 @@ private function processPlatformConfig(string $type, array $platform, ContainerB return; } + if ('ollama' === $type) { + $platformId = 'ai.platform.ollama'; + $definition = (new Definition(Platform::class)) + ->setFactory(MistralPlatformFactory::class.'::create') + ->setFactory(OllamaPlatformFactory::class.'::create') + ->setLazy(true) + ->addTag('proxy', ['interface' => PlatformInterface::class]) + ->setArguments([ + $platform['host_url'], + new Reference('http_client', ContainerInterface::NULL_ON_INVALID_REFERENCE), + new Reference('ai.platform.contract.ollama'), + ]) + ->addTag('ai.platform'); + + $container->setDefinition($platformId, $definition); + + return; + } + + if ('cerebras' === $type && isset($platform['api_key'])) { + $platformId = 'ai.platform.cerebras'; + $definition = (new Definition(Platform::class)) + ->setFactory(CerebrasPlatformFactory::class.'::create') + ->setLazy(true) + ->addTag('proxy', ['interface' => PlatformInterface::class]) + ->setArguments([ + $platform['api_key'], + new Reference('http_client', 
ContainerInterface::NULL_ON_INVALID_REFERENCE), + ]) + ->addTag('ai.platform'); + + $container->setDefinition($platformId, $definition); + + return; + } + throw new InvalidArgumentException(\sprintf('Platform "%s" is not supported for configuration via bundle at this point.', $type)); } @@ -356,11 +404,9 @@ private function processAgentConfig(string $name, array $config, ContainerBuilde $container->setDefinition('ai.toolbox.'.$name, $toolboxDefinition); if ($config['fault_tolerant_toolbox']) { - $faultTolerantToolboxDefinition = (new Definition('ai.fault_tolerant_toolbox.'.$name)) - ->setClass(FaultTolerantToolbox::class) + $container->setDefinition('ai.fault_tolerant_toolbox.'.$name, new Definition(FaultTolerantToolbox::class)) ->setArguments([new Reference('.inner')]) ->setDecoratedService('ai.toolbox.'.$name); - $container->setDefinition('ai.fault_tolerant_toolbox.'.$name, $faultTolerantToolboxDefinition); } if ($container->getParameter('kernel.debug')) { @@ -379,6 +425,12 @@ private function processAgentConfig(string $name, array $config, ContainerBuilde $inputProcessors[] = new Reference('ai.tool.agent_processor.'.$name); $outputProcessors[] = new Reference('ai.tool.agent_processor.'.$name); } else { + if ($config['fault_tolerant_toolbox'] && !$container->hasDefinition('ai.fault_tolerant_toolbox')) { + $container->setDefinition('ai.fault_tolerant_toolbox', new Definition(FaultTolerantToolbox::class)) + ->setArguments([new Reference('.inner')]) + ->setDecoratedService('ai.toolbox'); + } + $inputProcessors[] = new Reference('ai.tool.agent_processor'); $outputProcessors[] = new Reference('ai.tool.agent_processor'); } @@ -438,6 +490,21 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde } } + if ('cache' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + new Reference($store['service']), + ]; + + $definition = new Definition(CacheStore::class); + $definition + ->addTag('ai.store') + ->setArguments($arguments); 
+ + $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); + } + } + if ('chroma_db' === $type) { foreach ($stores as $name => $store) { $definition = new Definition(ChromaDbStore::class); @@ -452,6 +519,77 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde } } + if ('clickhouse' === $type) { + foreach ($stores as $name => $store) { + if (isset($store['http_client'])) { + $httpClient = new Reference($store['http_client']); + } else { + $httpClient = new Definition(HttpClientInterface::class); + $httpClient + ->setFactory([HttpClient::class, 'createForBaseUri']) + ->setArguments([$store['dsn']]) + ; + } + + $definition = new Definition(ClickHouseStore::class); + $definition + ->setArguments([ + $httpClient, + $store['database'], + $store['table'], + ]) + ->addTag('ai.store') + ; + + $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); + } + } + + if ('meilisearch' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + new Reference('http_client'), + $store['endpoint'], + $store['api_key'], + $store['index_name'], + ]; + + if (\array_key_exists('embedder', $store)) { + $arguments[4] = $store['embedder']; + } + + if (\array_key_exists('vector_field', $store)) { + $arguments[5] = $store['vector_field']; + } + + if (\array_key_exists('dimensions', $store)) { + $arguments[6] = $store['dimensions']; + } + + $definition = new Definition(MeilisearchStore::class); + $definition + ->addTag('ai.store') + ->setArguments($arguments); + + $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); + } + } + + if ('memory' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + $store['distance'], + ]; + + $definition = new Definition(InMemoryStore::class); + $definition + ->addTag('ai.store') + ->setArguments($arguments); + + $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); + } + } + if ('mongodb' === $type) { foreach ($stores as $name => $store) { 
$arguments = [ @@ -478,6 +616,43 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde } } + if ('neo4j' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + new Reference('http_client'), + $store['endpoint'], + $store['username'], + $store['password'], + $store['database'], + $store['vector_index_name'], + $store['node_name'], + ]; + + if (\array_key_exists('vector_field', $store)) { + $arguments[7] = $store['vector_field']; + } + + if (\array_key_exists('dimensions', $store)) { + $arguments[8] = $store['dimensions']; + } + + if (\array_key_exists('distance', $store)) { + $arguments[9] = $store['distance']; + } + + if (\array_key_exists('quantization', $store)) { + $arguments[10] = $store['quantization']; + } + + $definition = new Definition(Neo4jStore::class); + $definition + ->addTag('ai.store') + ->setArguments($arguments); + + $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); + } + } + if ('pinecone' === $type) { foreach ($stores as $name => $store) { $arguments = [ @@ -501,6 +676,98 @@ private function processStoreConfig(string $type, array $stores, ContainerBuilde $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); } } + + if ('qdrant' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + new Reference('http_client'), + $store['endpoint'], + $store['api_key'], + $store['collection_name'], + ]; + + if (\array_key_exists('dimensions', $store)) { + $arguments[4] = $store['dimensions']; + } + + if (\array_key_exists('distance', $store)) { + $arguments[5] = $store['distance']; + } + + $definition = new Definition(QdrantStore::class); + $definition + ->addTag('ai.store') + ->setArguments($arguments); + + $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); + } + } + + if ('surreal_db' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + new Reference('http_client'), + $store['endpoint'], + $store['username'], + 
$store['password'], + $store['namespace'], + $store['database'], + ]; + + if (\array_key_exists('table', $store)) { + $arguments[6] = $store['table']; + } + + if (\array_key_exists('vector_field', $store)) { + $arguments[7] = $store['vector_field']; + } + + if (\array_key_exists('strategy', $store)) { + $arguments[8] = $store['strategy']; + } + + if (\array_key_exists('dimensions', $store)) { + $arguments[9] = $store['dimensions']; + } + + if (\array_key_exists('namespaced_user', $store)) { + $arguments[10] = $store['namespaced_user']; + } + + $definition = new Definition(SurrealDbStore::class); + $definition + ->addTag('ai.store') + ->setArguments($arguments); + + $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); + } + } + + if ('typesense' === $type) { + foreach ($stores as $name => $store) { + $arguments = [ + new Reference('http_client'), + $store['endpoint'], + $store['api_key'], + $store['collection'], + ]; + + if (\array_key_exists('vector_field', $store)) { + $arguments[4] = $store['vector_field']; + } + + if (\array_key_exists('dimensions', $store)) { + $arguments[5] = $store['dimensions']; + } + + $definition = new Definition(TypesenseStore::class); + $definition + ->addTag('ai.store') + ->setArguments($arguments); + + $container->setDefinition('ai.store.'.$type.'.'.$name, $definition); + } + } } /** diff --git a/src/ai-bundle/templates/data_collector.html.twig b/src/ai-bundle/templates/data_collector.html.twig index c69384290..359189784 100644 --- a/src/ai-bundle/templates/data_collector.html.twig +++ b/src/ai-bundle/templates/data_collector.html.twig @@ -48,7 +48,7 @@
    {% for toolCall in toolCalls %}
  1. - {{ toolCall.name }}({{ toolCall.arguments|map((value, key) => "#{key}: #{value}")|join(', ') }}) + {{ toolCall.name }}({{ toolCall.arguments|map((value, key) => "#{key}: #{value|json_encode}")|join(', ') }}) (ID: {{ toolCall.id }})
  2. {% endfor %} diff --git a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php index a8d65e10c..ba0b1993f 100644 --- a/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php +++ b/src/ai-bundle/tests/DependencyInjection/AiBundleTest.php @@ -13,6 +13,7 @@ use PHPUnit\Framework\Attributes\CoversClass; use PHPUnit\Framework\Attributes\DoesNotPerformAssertions; +use PHPUnit\Framework\Attributes\TestWith; use PHPUnit\Framework\Attributes\UsesClass; use PHPUnit\Framework\TestCase; use Symfony\AI\AiBundle\AiBundle; @@ -29,6 +30,46 @@ public function testExtensionLoadDoesNotThrow() $this->buildContainer($this->getFullConfig()); } + #[TestWith([true], 'enabled')] + #[TestWith([false], 'disabled')] + public function testFaultTolerantAgentSpecificToolbox(bool $enabled) + { + $container = $this->buildContainer([ + 'ai' => [ + 'agent' => [ + 'my_agent' => [ + 'model' => ['class' => 'Symfony\AI\Platform\Bridge\OpenAi\Gpt'], + 'tools' => [ + ['service' => 'some_service', 'description' => 'Some tool'], + ], + 'fault_tolerant_toolbox' => $enabled, + ], + ], + ], + ]); + + $this->assertSame($enabled, $container->hasDefinition('ai.fault_tolerant_toolbox.my_agent')); + } + + #[TestWith([true], 'enabled')] + #[TestWith([false], 'disabled')] + public function testFaultTolerantDefaultToolbox(bool $enabled) + { + $container = $this->buildContainer([ + 'ai' => [ + 'agent' => [ + 'my_agent' => [ + 'model' => ['class' => 'Symfony\AI\Platform\Bridge\OpenAi\Gpt'], + 'tools' => true, + 'fault_tolerant_toolbox' => $enabled, + ], + ], + ], + ]); + + $this->assertSame($enabled, $container->hasDefinition('ai.fault_tolerant_toolbox')); + } + public function testAgentsCanBeRegisteredAsTools() { $container = $this->buildContainer([ @@ -117,6 +158,9 @@ private function getFullConfig(): array 'lmstudio' => [ 'host_url' => 'http://127.0.0.1:1234', ], + 'ollama' => [ + 'host_url' => 'http://127.0.0.1:11434', + ], ], 'agent' => [ 
'my_chat_agent' => [ @@ -157,11 +201,38 @@ private function getFullConfig(): array 'vector_field' => 'contentVector', ], ], + 'cache' => [ + 'my_cache_store' => [ + 'service' => 'cache.system', + ], + ], 'chroma_db' => [ 'my_chroma_store' => [ 'collection' => 'my_collection', ], ], + 'clickhouse' => [ + 'my_clickhouse_store' => [ + 'dsn' => 'http://foo:bar@1.2.3.4:9999', + 'database' => 'my_db', + 'table' => 'my_table', + ], + ], + 'meilisearch' => [ + 'my_meilisearch_store' => [ + 'endpoint' => 'http://127.0.0.1:7700', + 'api_key' => 'foo', + 'index_name' => 'test', + 'embedder' => 'default', + 'vector_field' => '_vectors', + 'dimensions' => 768, + ], + ], + 'memory' => [ + 'my_memory_store' => [ + 'distance' => 'cosine', + ], + ], 'mongodb' => [ 'my_mongo_store' => [ 'database' => 'my_db', @@ -171,6 +242,20 @@ private function getFullConfig(): array 'bulk_write' => true, ], ], + 'neo4j' => [ + 'my_neo4j_store' => [ + 'endpoint' => 'http://127.0.0.1:8000', + 'username' => 'test', + 'password' => 'test', + 'database' => 'foo', + 'vector_index_name' => 'test', + 'node_name' => 'foo', + 'vector_field' => '_vectors', + 'dimensions' => 768, + 'distance' => 'cosine', + 'quantization' => true, + ], + ], 'pinecone' => [ 'my_pinecone_store' => [ 'namespace' => 'my_namespace', @@ -178,6 +263,38 @@ private function getFullConfig(): array 'top_k' => 10, ], ], + 'qdrant' => [ + 'my_qdrant_store' => [ + 'endpoint' => 'http://127.0.0.1:8000', + 'api_key' => 'test', + 'collection_name' => 'foo', + 'dimensions' => 768, + 'distance' => 'Cosine', + ], + ], + 'surreal_db' => [ + 'my_surreal_db_store' => [ + 'endpoint' => 'http://127.0.0.1:8000', + 'username' => 'test', + 'password' => 'test', + 'namespace' => 'foo', + 'database' => 'bar', + 'table' => 'bar', + 'vector_field' => '_vectors', + 'strategy' => 'cosine', + 'dimensions' => 768, + 'namespaced_user' => true, + ], + ], + 'typesense' => [ + 'my_typesense_store' => [ + 'endpoint' => 'http://localhost:8108', + 'api_key' => 'foo', 
+ 'collection' => 'my_collection', + 'vector_field' => 'vector', + 'dimensions' => 768, + ], + ], ], 'indexer' => [ 'my_text_indexer' => [ diff --git a/src/ai-bundle/tests/Security/IsGrantedToolAttributeListenerTest.php b/src/ai-bundle/tests/Security/IsGrantedToolAttributeListenerTest.php index 0f4d00ea4..1bf74f8b8 100644 --- a/src/ai-bundle/tests/Security/IsGrantedToolAttributeListenerTest.php +++ b/src/ai-bundle/tests/Security/IsGrantedToolAttributeListenerTest.php @@ -56,8 +56,8 @@ public function testItWillThrowWhenNotGranted(object $tool, Tool $metadata) { $this->authChecker->expects($this->once())->method('isGranted')->willReturn(false); - self::expectException(AccessDeniedException::class); - self::expectExceptionMessage(\sprintf('No access to %s tool.', $metadata->name)); + $this->expectException(AccessDeniedException::class); + $this->expectExceptionMessage(\sprintf('No access to %s tool.', $metadata->name)); $this->dispatcher->dispatch(new ToolCallArgumentsResolved($tool, $metadata, [])); } diff --git a/src/mcp-sdk/composer.json b/src/mcp-sdk/composer.json index b692e3df6..e515ae898 100644 --- a/src/mcp-sdk/composer.json +++ b/src/mcp-sdk/composer.json @@ -22,7 +22,6 @@ "phpstan/phpstan": "^2.1", "phpunit/phpunit": "^11.5", "symfony/console": "^6.4 || ^7.0", - "rector/rector": "^2.0", "psr/cache": "^3.0" }, "suggest": { diff --git a/src/platform/CHANGELOG.md b/src/platform/CHANGELOG.md index 34eb978ed..1570ad0f6 100644 --- a/src/platform/CHANGELOG.md +++ b/src/platform/CHANGELOG.md @@ -21,6 +21,7 @@ CHANGELOG - HuggingFace (extensive model support with multiple tasks) - TransformersPHP (local PHP-based transformer models) - LM Studio (local model hosting) + - Cerebras (language models like Llama 4, Qwen 3, and more) * Add comprehensive message system with role-based messaging: - `UserMessage` for user inputs with multi-modal content - `SystemMessage` for system instructions diff --git a/src/platform/doc/index.rst b/src/platform/doc/index.rst index 
3bc300796..98dc0fb34 100644 --- a/src/platform/doc/index.rst +++ b/src/platform/doc/index.rst @@ -289,7 +289,7 @@ For unit or integration testing, you can use the `InMemoryPlatform`, which imple It supports returning either: - A fixed string result -- A callable that dynamically returns a response based on the model, input, and options:: +- A callable that dynamically returns a simple string or any ``ResultInterface`` based on the model, input, and options:: use Symfony\AI\Platform\InMemoryPlatform; use Symfony\AI\Platform\Model; @@ -300,8 +300,41 @@ It supports returning either: echo $result->asText(); // "Fake result" +**Dynamic Text Results**:: -Internally, it uses `InMemoryRawResult` to simulate the behavior of real API responses and support `ResultPromise`. + $platform = new InMemoryPlatform( + fn($model, $input, $options) => "Echo: {$input}" + ); + + $result = $platform->invoke(new Model('test'), 'Hello AI'); + echo $result->asText(); // "Echo: Hello AI" + +**Vector Results**:: + + use Symfony\AI\Platform\Result\VectorResult; + + $platform = new InMemoryPlatform( + fn() => new VectorResult(new Vector([0.1, 0.2, 0.3, 0.4])) + ); + + $result = $platform->invoke(new Model('test'), 'vectorize this text'); + $vectors = $result->asVectors(); // Returns Vector object with [0.1, 0.2, 0.3, 0.4] + +**Binary Results**:: + + use Symfony\AI\Platform\Result\BinaryResult; + + $platform = new InMemoryPlatform( + fn() => new BinaryResult('fake-pdf-content', 'application/pdf') + ); + + $result = $platform->invoke(new Model('test'), 'generate PDF document'); + $binary = $result->asBinary(); // Returns Binary object with content and MIME type + + +**Raw Results** + +The platform automatically uses the ``getRawResult()`` from any ``ResultInterface`` returned by closures. For string results, it creates an ``InMemoryRawResult`` to simulate real API response metadata. This allows fast and isolated testing of AI-powered features without relying on live providers or HTTP requests. 
@@ -313,6 +346,8 @@ This allows fast and isolated testing of AI-powered features without relying on * `Parallel GPT Calls`_ * `Parallel Embeddings Calls`_ +* `Cerebras Chat`_ +* `Cerebras Streaming`_ .. note:: @@ -359,3 +394,5 @@ This allows fast and isolated testing of AI-powered features without relying on .. _`Parallel Embeddings Calls`: https://github.com/symfony/ai/blob/main/examples/misc/parallel-embeddings.php .. _`LM Studio`: https://lmstudio.ai/ .. _`LM Studio Catalog`: https://lmstudio.ai/models +.. _`Cerebras Chat`: https://github.com/symfony/ai/blob/main/examples/cerebras/chat.php +.. _`Cerebras Streaming`: https://github.com/symfony/ai/blob/main/examples/cerebras/stream.php diff --git a/src/platform/src/Bridge/Anthropic/Claude.php b/src/platform/src/Bridge/Anthropic/Claude.php index 66d92b26b..56d0cc18c 100644 --- a/src/platform/src/Bridge/Anthropic/Claude.php +++ b/src/platform/src/Bridge/Anthropic/Claude.php @@ -29,6 +29,7 @@ class Claude extends Model public const OPUS_3 = 'claude-3-opus-20240229'; public const OPUS_4 = 'claude-opus-4-20250514'; public const OPUS_4_0 = 'claude-opus-4-0'; + public const OPUS_4_1 = 'claude-opus-4-1'; /** * @param array $options The default options for the model usage diff --git a/src/platform/src/Bridge/Anthropic/Contract/AssistantMessageNormalizer.php b/src/platform/src/Bridge/Anthropic/Contract/AssistantMessageNormalizer.php index c03e034e7..09c256a72 100644 --- a/src/platform/src/Bridge/Anthropic/Contract/AssistantMessageNormalizer.php +++ b/src/platform/src/Bridge/Anthropic/Contract/AssistantMessageNormalizer.php @@ -26,22 +26,12 @@ final class AssistantMessageNormalizer extends ModelContractNormalizer implement { use NormalizerAwareTrait; - protected function supportedDataClass(): string - { - return AssistantMessage::class; - } - - protected function supportsModel(Model $model): bool - { - return $model instanceof Claude; - } - /** * @param AssistantMessage $data * * @return array{ * role: 'assistant', - * 
content: list 'assistant', - 'content' => array_map(static function (ToolCall $toolCall) { + 'content' => $data->hasToolCalls() ? array_map(static function (ToolCall $toolCall) { return [ 'type' => 'tool_use', 'id' => $toolCall->id, 'name' => $toolCall->name, 'input' => [] !== $toolCall->arguments ? $toolCall->arguments : new \stdClass(), ]; - }, $data->toolCalls), + }, $data->toolCalls) : $data->content, ]; } + + protected function supportedDataClass(): string + { + return AssistantMessage::class; + } + + protected function supportsModel(Model $model): bool + { + return $model instanceof Claude; + } } diff --git a/src/platform/src/Bridge/Cerebras/Model.php b/src/platform/src/Bridge/Cerebras/Model.php new file mode 100644 index 000000000..3df43a82f --- /dev/null +++ b/src/platform/src/Bridge/Cerebras/Model.php @@ -0,0 +1,39 @@ + + */ +final class Model extends BaseModel +{ + public const LLAMA_4_SCOUT_17B_16E_INSTRUCT = 'llama-4-scout-17b-16e-instruct'; + public const LLAMA3_1_8B = 'llama3.1-8b'; + public const LLAMA_3_3_70B = 'llama-3.3-70b'; + public const LLAMA_4_MAVERICK_17B_128E_INSTRUCT = 'llama-4-maverick-17b-128e-instruct'; + public const QWEN_3_32B = 'qwen-3-32b'; + public const QWEN_3_235B_A22B_INSTRUCT_2507 = 'qwen-3-235b-a22b-instruct-2507'; + public const QWEN_3_235B_A22B_THINKING_2507 = 'qwen-3-235b-a22b-thinking-2507'; + public const QWEN_3_CODER_480B = 'qwen-3-coder-480b'; + public const GPT_OSS_120B = 'gpt-oss-120b'; + + public const CAPABILITIES = [ + Capability::INPUT_MESSAGES, + Capability::OUTPUT_TEXT, + Capability::OUTPUT_STREAMING, + ]; + + /** + * @see https://inference-docs.cerebras.ai/api-reference/chat-completions for details like options + */ + public function __construct( + string $name = self::LLAMA3_1_8B, + array $capabilities = self::CAPABILITIES, + array $options = [], + ) { + parent::__construct($name, $capabilities, $options); + } +} diff --git a/src/platform/src/Bridge/Cerebras/ModelClient.php 
b/src/platform/src/Bridge/Cerebras/ModelClient.php new file mode 100644 index 000000000..2ff201b5b --- /dev/null +++ b/src/platform/src/Bridge/Cerebras/ModelClient.php @@ -0,0 +1,64 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Cerebras; + +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Model as BaseModel; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Result\RawHttpResult; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Junaid Farooq + */ +final readonly class ModelClient implements ModelClientInterface +{ + private EventSourceHttpClient $httpClient; + + public function __construct( + HttpClientInterface $httpClient, + #[\SensitiveParameter] private string $apiKey, + ) { + if ('' === $apiKey) { + throw new InvalidArgumentException('The API key must not be empty.'); + } + + if (!str_starts_with($apiKey, 'csk-')) { + throw new InvalidArgumentException('The API key must start with "csk-".'); + } + + $this->httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + } + + public function supports(BaseModel $model): bool + { + return $model instanceof Model; + } + + public function request(BaseModel $model, array|string $payload, array $options = []): RawHttpResult + { + return new RawHttpResult( + $this->httpClient->request( + 'POST', 'https://api.cerebras.ai/v1/chat/completions', + [ + 'headers' => [ + 'Content-Type' => 'application/json', + 'Authorization' => sprintf('Bearer %s', $this->apiKey), + ], + 'json' => \is_array($payload) ? 
array_merge($payload, $options) : $payload, + ] + ) + ); + } +} + diff --git a/src/platform/src/Bridge/Cerebras/PlatformFactory.php b/src/platform/src/Bridge/Cerebras/PlatformFactory.php new file mode 100644 index 000000000..27d277c30 --- /dev/null +++ b/src/platform/src/Bridge/Cerebras/PlatformFactory.php @@ -0,0 +1,34 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Bridge\Cerebras; + +use Symfony\AI\Platform\Platform; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Junaid Farooq + */ +final readonly class PlatformFactory +{ + public static function create( + #[\SensitiveParameter] string $apiKey, + ?HttpClientInterface $httpClient = null, + ): Platform { + $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); + + return new Platform( + [new ModelClient($httpClient, $apiKey)], + [new ResultConverter()], + ); + } +} diff --git a/src/platform/src/Bridge/Cerebras/ResultConverter.php b/src/platform/src/Bridge/Cerebras/ResultConverter.php new file mode 100644 index 000000000..d6310ba48 --- /dev/null +++ b/src/platform/src/Bridge/Cerebras/ResultConverter.php @@ -0,0 +1,77 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +namespace Symfony\AI\Platform\Bridge\Cerebras; + +use Symfony\AI\Platform\Exception\RuntimeException; +use Symfony\AI\Platform\Model as BaseModel; +use Symfony\AI\Platform\Result\RawHttpResult; +use Symfony\AI\Platform\Result\RawResultInterface; +use Symfony\AI\Platform\Result\ResultInterface; +use Symfony\AI\Platform\Result\StreamResult; +use Symfony\AI\Platform\Result\TextResult; +use Symfony\AI\Platform\ResultConverterInterface; +use Symfony\Component\HttpClient\Chunk\ServerSentEvent; +use Symfony\Component\HttpClient\EventSourceHttpClient; +use Symfony\Component\HttpClient\Exception\JsonException; +use Symfony\Contracts\HttpClient\ResponseInterface as HttpResponse; + +/** + * @author Junaid Farooq + */ +final readonly class ResultConverter implements ResultConverterInterface +{ + public function supports(BaseModel $model): bool + { + return $model instanceof Model; + } + + public function convert(RawHttpResult|RawResultInterface $result, array $options = []): ResultInterface + { + if ($options['stream'] ?? 
false) { + return new StreamResult($this->convertStream($result->getObject())); + } + + $data = $result->getData(); + + if (!isset($data['choices'][0]['message']['content'])) { + if (isset($data['type'], $data['message']) && str_ends_with($data['type'], 'error')) { + throw new RuntimeException(sprintf('Cerebras API error: %s', $data['message'])); + } + + throw new RuntimeException('Response does not contain output.'); + } + + return new TextResult($data['choices'][0]['message']['content']); + } + + private function convertStream(HttpResponse $result): \Generator + { + foreach ((new EventSourceHttpClient())->stream($result) as $chunk) { + if (!$chunk instanceof ServerSentEvent || '[DONE]' === $chunk->getData()) { + continue; + } + + try { + $data = $chunk->getArrayData(); + } catch (JsonException) { + continue; + } + + if (!isset($data['choices'][0]['delta']['content'])) { + continue; + } + + yield $data['choices'][0]['delta']['content']; + } + } +} + diff --git a/src/platform/src/Bridge/Ollama/Contract/AssistantMessageNormalizer.php b/src/platform/src/Bridge/Ollama/Contract/AssistantMessageNormalizer.php index c6d919c4a..736fad0f8 100644 --- a/src/platform/src/Bridge/Ollama/Contract/AssistantMessageNormalizer.php +++ b/src/platform/src/Bridge/Ollama/Contract/AssistantMessageNormalizer.php @@ -42,6 +42,7 @@ protected function supportsModel(Model $model): bool * * @return array{ * role: Role::Assistant, + * content: string, * tool_calls: list Role::Assistant, + 'content' => $data->content ?? 
'', 'tool_calls' => array_values(array_map(function (ToolCall $message): array { return [ 'type' => 'function', diff --git a/src/platform/src/Bridge/Ollama/Ollama.php b/src/platform/src/Bridge/Ollama/Ollama.php index bd0e16118..c31eef3de 100644 --- a/src/platform/src/Bridge/Ollama/Ollama.php +++ b/src/platform/src/Bridge/Ollama/Ollama.php @@ -37,11 +37,15 @@ class Ollama extends Model public const QWEN = 'qwen'; public const QWEN_2 = 'qwen2'; public const LLAMA_2 = 'llama2'; + public const NOMIC_EMBED_TEXT = 'nomic-embed-text'; + public const BGE_M3 = 'bge-m3'; + public const ALL_MINILM = 'all-minilm'; private const TOOL_PATTERNS = [ '/./' => [ Capability::INPUT_MESSAGES, Capability::OUTPUT_TEXT, + Capability::OUTPUT_STRUCTURED, ], '/^llama\D*3(\D*\d+)/' => [ Capability::TOOL_CALLING, @@ -52,6 +56,9 @@ class Ollama extends Model '/^(deepseek|mistral)/' => [ Capability::TOOL_CALLING, ], + '/^(nomic|bge|all-minilm).*/' => [ + Capability::INPUT_MULTIPLE, + ], ]; /** diff --git a/src/platform/src/Bridge/Ollama/OllamaClient.php b/src/platform/src/Bridge/Ollama/OllamaClient.php new file mode 100644 index 000000000..f1220ed79 --- /dev/null +++ b/src/platform/src/Bridge/Ollama/OllamaClient.php @@ -0,0 +1,90 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +namespace Symfony\AI\Platform\Bridge\Ollama; + +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\ModelClientInterface; +use Symfony\AI\Platform\Result\RawHttpResult; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Christopher Hertel + */ +final readonly class OllamaClient implements ModelClientInterface +{ + public function __construct( + private HttpClientInterface $httpClient, + private string $hostUrl, + ) { + } + + public function supports(Model $model): bool + { + return $model instanceof Ollama; + } + + public function request(Model $model, array|string $payload, array $options = []): RawHttpResult + { + $response = $this->httpClient->request('POST', \sprintf('%s/api/show', $this->hostUrl), [ + 'json' => [ + 'model' => $model->getName(), + ], + ]); + + $capabilities = $response->toArray()['capabilities'] ?? null; + + if (null === $capabilities) { + throw new InvalidArgumentException('The model information could not be retrieved from the Ollama API. Your Ollama server might be too old. 
Try upgrade it.'); + } + + return match (true) { + \in_array('completion', $capabilities, true) => $this->doCompletionRequest($payload, $options), + \in_array('embedding', $capabilities, true) => $this->doEmbeddingsRequest($model, $payload, $options), + default => throw new InvalidArgumentException(\sprintf('Unsupported model "%s": "%s".', $model::class, $model->getName())), + }; + } + + /** + * @param array $payload + * @param array $options + */ + private function doCompletionRequest(array|string $payload, array $options = []): RawHttpResult + { + // Revert Ollama's default streaming behavior + $options['stream'] ??= false; + + if (\array_key_exists('response_format', $options) && \array_key_exists('json_schema', $options['response_format'])) { + $options['format'] = $options['response_format']['json_schema']['schema']; + unset($options['response_format']); + } + + return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [ + 'headers' => ['Content-Type' => 'application/json'], + 'json' => array_merge($options, $payload), + ])); + } + + /** + * @param array $payload + * @param array $options + */ + private function doEmbeddingsRequest(Model $model, array|string $payload, array $options = []): RawHttpResult + { + return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/api/embed', $this->hostUrl), [ + 'json' => array_merge($options, [ + 'model' => $model->getName(), + 'input' => $payload, + ]), + ])); + } +} diff --git a/src/platform/src/Bridge/Ollama/OllamaModelClient.php b/src/platform/src/Bridge/Ollama/OllamaModelClient.php deleted file mode 100644 index 03c651466..000000000 --- a/src/platform/src/Bridge/Ollama/OllamaModelClient.php +++ /dev/null @@ -1,45 +0,0 @@ - - * - * For the full copyright and license information, please view the LICENSE - * file that was distributed with this source code. 
- */ - -namespace Symfony\AI\Platform\Bridge\Ollama; - -use Symfony\AI\Platform\Model; -use Symfony\AI\Platform\ModelClientInterface; -use Symfony\AI\Platform\Result\RawHttpResult; -use Symfony\Contracts\HttpClient\HttpClientInterface; - -/** - * @author Christopher Hertel - */ -final readonly class OllamaModelClient implements ModelClientInterface -{ - public function __construct( - private HttpClientInterface $httpClient, - private string $hostUrl, - ) { - } - - public function supports(Model $model): bool - { - return $model instanceof Ollama; - } - - public function request(Model $model, array|string $payload, array $options = []): RawHttpResult - { - // Revert Ollama's default streaming behavior - $options['stream'] ??= false; - - return new RawHttpResult($this->httpClient->request('POST', \sprintf('%s/api/chat', $this->hostUrl), [ - 'headers' => ['Content-Type' => 'application/json'], - 'json' => array_merge($options, $payload), - ])); - } -} diff --git a/src/platform/src/Bridge/Ollama/OllamaResultConverter.php b/src/platform/src/Bridge/Ollama/OllamaResultConverter.php index edbf925e5..30e40c8db 100644 --- a/src/platform/src/Bridge/Ollama/OllamaResultConverter.php +++ b/src/platform/src/Bridge/Ollama/OllamaResultConverter.php @@ -18,7 +18,9 @@ use Symfony\AI\Platform\Result\TextResult; use Symfony\AI\Platform\Result\ToolCall; use Symfony\AI\Platform\Result\ToolCallResult; +use Symfony\AI\Platform\Result\VectorResult; use Symfony\AI\Platform\ResultConverterInterface; +use Symfony\AI\Platform\Vector\Vector; /** * @author Christopher Hertel @@ -34,6 +36,16 @@ public function convert(RawResultInterface $result, array $options = []): Result { $data = $result->getData(); + return \array_key_exists('embeddings', $data) + ? 
$this->doConvertEmbeddings($data) + : $this->doConvertCompletion($data); + } + + /** + * @param array $data + */ + public function doConvertCompletion(array $data): ResultInterface + { if (!isset($data['message'])) { throw new RuntimeException('Response does not contain message.'); } @@ -54,4 +66,21 @@ public function convert(RawResultInterface $result, array $options = []): Result return new TextResult($data['message']['content']); } + + /** + * @param array $data + */ + public function doConvertEmbeddings(array $data): ResultInterface + { + if ([] === $data['embeddings']) { + throw new RuntimeException('Response does not contain embeddings.'); + } + + return new VectorResult( + ...array_map( + static fn (array $embedding): Vector => new Vector($embedding), + $data['embeddings'], + ), + ); + } } diff --git a/src/platform/src/Bridge/Ollama/PlatformFactory.php b/src/platform/src/Bridge/Ollama/PlatformFactory.php index fcca081a0..af9b490ba 100644 --- a/src/platform/src/Bridge/Ollama/PlatformFactory.php +++ b/src/platform/src/Bridge/Ollama/PlatformFactory.php @@ -29,6 +29,10 @@ public static function create( ): Platform { $httpClient = $httpClient instanceof EventSourceHttpClient ? $httpClient : new EventSourceHttpClient($httpClient); - return new Platform([new OllamaModelClient($httpClient, $hostUrl)], [new OllamaResultConverter()], $contract ?? OllamaContract::create()); + return new Platform( + [new OllamaClient($httpClient, $hostUrl)], + [new OllamaResultConverter()], + $contract ?? 
OllamaContract::create() + ); } } diff --git a/src/platform/src/InMemoryPlatform.php b/src/platform/src/InMemoryPlatform.php index 7d5b6ce49..7d1ee744f 100644 --- a/src/platform/src/InMemoryPlatform.php +++ b/src/platform/src/InMemoryPlatform.php @@ -12,6 +12,7 @@ namespace Symfony\AI\Platform; use Symfony\AI\Platform\Result\InMemoryRawResult; +use Symfony\AI\Platform\Result\ResultInterface; use Symfony\AI\Platform\Result\ResultPromise; use Symfony\AI\Platform\Result\TextResult; @@ -26,7 +27,7 @@ class InMemoryPlatform implements PlatformInterface { /** * The mock result can be a string or a callable that returns a string. - * If it's a closure, it receives the model, input, and optionally options as parameters like a real platform call. + * If it's a closure, it receives the model, input, and optionally options as parameters like a real platform call. */ public function __construct(private readonly \Closure|string $mockResult) { @@ -34,19 +35,28 @@ public function __construct(private readonly \Closure|string $mockResult) public function invoke(Model $model, array|string|object $input, array $options = []): ResultPromise { - $resultText = $this->mockResult instanceof \Closure - ? ($this->mockResult)($model, $input, $options) - : $this->mockResult; - - $textResult = new TextResult($resultText); - - return new ResultPromise( - static fn () => $textResult, - rawResult: new InMemoryRawResult( - ['text' => $resultText], - (object) ['text' => $resultText], - ), - options: $options + $result = \is_string($this->mockResult) ? $this->mockResult : ($this->mockResult)($model, $input, $options); + + if ($result instanceof ResultInterface) { + return $this->createPromise($result, $options); + } + + return $this->createPromise(new TextResult($result), $options); + } + + /** + * Creates a ResultPromise from a ResultInterface. 
+ * + * @param ResultInterface $result The result to wrap in a promise + * @param array $options Additional options for the promise + */ + private function createPromise(ResultInterface $result, array $options): ResultPromise + { + $rawResult = $result->getRawResult() ?? new InMemoryRawResult( + ['text' => $result->getContent()], + (object) ['text' => $result->getContent()], ); + + return new ResultPromise(static fn () => $result, $rawResult, $options); } } diff --git a/src/platform/tests/Bridge/Anthropic/Contract/AssistantMessageNormalizerTest.php b/src/platform/tests/Bridge/Anthropic/Contract/AssistantMessageNormalizerTest.php new file mode 100644 index 000000000..72ba32b29 --- /dev/null +++ b/src/platform/tests/Bridge/Anthropic/Contract/AssistantMessageNormalizerTest.php @@ -0,0 +1,113 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Anthropic\Contract; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\DataProvider; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Anthropic\Claude; +use Symfony\AI\Platform\Bridge\Anthropic\Contract\AssistantMessageNormalizer; +use Symfony\AI\Platform\Contract; +use Symfony\AI\Platform\Message\AssistantMessage; +use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Result\ToolCall; + +#[Small] +#[CoversClass(AssistantMessageNormalizer::class)] +#[UsesClass(Claude::class)] +#[UsesClass(AssistantMessage::class)] +#[UsesClass(Model::class)] +#[UsesClass(ToolCall::class)] +final class AssistantMessageNormalizerTest extends TestCase +{ + public function testSupportsNormalization() + { + $normalizer = new AssistantMessageNormalizer(); + + $this->assertTrue($normalizer->supportsNormalization(new AssistantMessage('Hello'), context: [ + Contract::CONTEXT_MODEL => new 
Claude(), + ])); + $this->assertFalse($normalizer->supportsNormalization('not an assistant message')); + } + + public function testGetSupportedTypes() + { + $normalizer = new AssistantMessageNormalizer(); + + $this->assertSame([AssistantMessage::class => true], $normalizer->getSupportedTypes(null)); + } + + #[DataProvider('normalizeDataProvider')] + public function testNormalize(AssistantMessage $message, array $expectedOutput) + { + $normalizer = new AssistantMessageNormalizer(); + + $normalized = $normalizer->normalize($message); + + $this->assertEquals($expectedOutput, $normalized); + } + + /** + * @return iterable|\stdClass + * }> + * } + * }> + */ + public static function normalizeDataProvider(): iterable + { + yield 'assistant message' => [ + new AssistantMessage('Great to meet you. What would you like to know?'), + [ + 'role' => 'assistant', + 'content' => 'Great to meet you. What would you like to know?', + ], + ]; + yield 'function call' => [ + new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1', ['arg1' => '123'])]), + [ + 'role' => 'assistant', + 'content' => [ + [ + 'type' => 'tool_use', + 'id' => 'id1', + 'name' => 'name1', + 'input' => ['arg1' => '123'], + ], + ], + ], + ]; + yield 'function call without parameters' => [ + new AssistantMessage(toolCalls: [new ToolCall('id1', 'name1')]), + [ + 'role' => 'assistant', + 'content' => [ + [ + 'type' => 'tool_use', + 'id' => 'id1', + 'name' => 'name1', + 'input' => new \stdClass(), + ], + ], + ], + ]; + } +} diff --git a/src/platform/tests/Bridge/Cerebras/ModelClientTest.php b/src/platform/tests/Bridge/Cerebras/ModelClientTest.php new file mode 100644 index 000000000..14e97762b --- /dev/null +++ b/src/platform/tests/Bridge/Cerebras/ModelClientTest.php @@ -0,0 +1,95 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +namespace Symfony\AI\Platform\Tests\Bridge\Cerebras; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\TestWith; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Cerebras\Model; +use Symfony\AI\Platform\Bridge\Cerebras\ModelClient; +use Symfony\AI\Platform\Exception\InvalidArgumentException; +use Symfony\Component\HttpClient\MockHttpClient; +use Symfony\Component\HttpClient\Response\JsonMockResponse; + +/** + * @author Junaid Farooq + */ +#[CoversClass(ModelClient::class)] +#[UsesClass(Model::class)] +#[Small] +class ModelClientTest extends TestCase +{ + public function testItDoesNotAllowAnEmptyKey() + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('The API key must not be empty.'); + + new ModelClient(new MockHttpClient(), ''); + } + + #[TestWith(['api-key-without-prefix'])] + #[TestWith(['pk-api-key'])] + #[TestWith(['SK-api-key'])] + #[TestWith(['skapikey'])] + #[TestWith(['sk api-key'])] + #[TestWith(['sk'])] + public function testItVerifiesIfTheKeyStartsWithCsk(string $invalidApiKey) + { + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('The API key must start with "csk-".'); + + new ModelClient(new MockHttpClient(), $invalidApiKey); + } + + public function testItSupportsTheCorrectModel() + { + $client = new ModelClient(new MockHttpClient(), 'csk-1234567890abcdef'); + + self::assertTrue($client->supports(new Model(Model::GPT_OSS_120B))); + } + + public function testItSuccessfullyInvokesTheModel() + { + $expectedResponse = [ + 'model' => 'llama-3.3-70b', + 'input' => [ + 'messages' => [ + ['role' => 'user', 'content' => 'Hello, world!'], + ], + ], + 'temperature' => 0.5, + ]; + $httpClient = new MockHttpClient( + new JsonMockResponse($expectedResponse), + ); + + $client = new ModelClient($httpClient, 'csk-1234567890abcdef'); + + 
$payload = [ + 'messages' => [ + ['role' => 'user', 'content' => 'Hello, world!'], + ], + ]; + + $result = $client->request(new Model(Model::LLAMA_3_3_70B), $payload); + $data = $result->getData(); + $info = $result->getObject()->getInfo(); + + self::assertNotEmpty($data); + self::assertNotEmpty($info); + self::assertSame('POST', $info['http_method']); + self::assertSame('https://api.cerebras.ai/v1/chat/completions', $info['url']); + self::assertSame($expectedResponse, $data); + } +} diff --git a/src/platform/tests/Bridge/Cerebras/ResultConverterTest.php b/src/platform/tests/Bridge/Cerebras/ResultConverterTest.php new file mode 100644 index 000000000..a7185a137 --- /dev/null +++ b/src/platform/tests/Bridge/Cerebras/ResultConverterTest.php @@ -0,0 +1,37 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Platform\Tests\Bridge\Cerebras; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\Small; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Cerebras\Model; +use Symfony\AI\Platform\Bridge\Cerebras\ModelClient; +use Symfony\AI\Platform\Bridge\Cerebras\ResultConverter; +use Symfony\Component\HttpClient\MockHttpClient; + +/** + * @author Junaid Farooq + */ +#[CoversClass(ResultConverter::class)] +#[UsesClass(Model::class)] +#[Small] +class ResultConverterTest extends TestCase +{ + public function testItSupportsTheCorrectModel() + { + $client = new ModelClient(new MockHttpClient(), 'csk-1234567890abcdef'); + + $this->assertTrue($client->supports(new Model(Model::GPT_OSS_120B))); + } +} diff --git a/src/platform/tests/Bridge/Ollama/Contract/AssistantMessageNormalizerTest.php b/src/platform/tests/Bridge/Ollama/Contract/AssistantMessageNormalizerTest.php index 743d9a352..465b117c0 100644 --- a/src/platform/tests/Bridge/Ollama/Contract/AssistantMessageNormalizerTest.php 
+++ b/src/platform/tests/Bridge/Ollama/Contract/AssistantMessageNormalizerTest.php @@ -72,6 +72,7 @@ public static function normalizeDataProvider(): iterable new AssistantMessage('Hello'), [ 'role' => Role::Assistant, + 'content' => 'Hello', 'tool_calls' => [], ], ]; @@ -80,6 +81,7 @@ public static function normalizeDataProvider(): iterable new AssistantMessage(toolCalls: [new ToolCall('id1', 'function1', ['param' => 'value'])]), [ 'role' => Role::Assistant, + 'content' => '', 'tool_calls' => [ [ 'type' => 'function', @@ -96,6 +98,7 @@ public static function normalizeDataProvider(): iterable new AssistantMessage(toolCalls: [new ToolCall('id1', 'function1', [])]), [ 'role' => Role::Assistant, + 'content' => '', 'tool_calls' => [ [ 'type' => 'function', @@ -115,6 +118,37 @@ public static function normalizeDataProvider(): iterable ]), [ 'role' => Role::Assistant, + 'content' => '', + 'tool_calls' => [ + [ + 'type' => 'function', + 'function' => [ + 'name' => 'function1', + 'arguments' => ['param1' => 'value1'], + ], + ], + [ + 'type' => 'function', + 'function' => [ + 'name' => 'function2', + 'arguments' => ['param2' => 'value2'], + ], + ], + ], + ], + ]; + + yield 'assistant message with tool calls and content' => [ + new AssistantMessage( + content: 'Hello', + toolCalls: [ + new ToolCall('id1', 'function1', ['param1' => 'value1']), + new ToolCall('id2', 'function2', ['param2' => 'value2']), + ] + ), + [ + 'role' => Role::Assistant, + 'content' => 'Hello', 'tool_calls' => [ [ 'type' => 'function', diff --git a/src/platform/tests/Bridge/Ollama/OllamaClientTest.php b/src/platform/tests/Bridge/Ollama/OllamaClientTest.php new file mode 100644 index 000000000..fcb236647 --- /dev/null +++ b/src/platform/tests/Bridge/Ollama/OllamaClientTest.php @@ -0,0 +1,90 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +namespace Symfony\AI\Platform\Tests\Bridge\Ollama; + +use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\UsesClass; +use PHPUnit\Framework\TestCase; +use Symfony\AI\Platform\Bridge\Ollama\Ollama; +use Symfony\AI\Platform\Bridge\Ollama\OllamaClient; +use Symfony\AI\Platform\Model; +use Symfony\Component\HttpClient\MockHttpClient; +use Symfony\Component\HttpClient\Response\JsonMockResponse; + +#[CoversClass(OllamaClient::class)] +#[UsesClass(Ollama::class)] +#[UsesClass(Model::class)] +final class OllamaClientTest extends TestCase +{ + public function testSupportsModel() + { + $client = new OllamaClient(new MockHttpClient(), 'http://localhost:1234'); + + $this->assertTrue($client->supports(new Ollama())); + $this->assertFalse($client->supports(new Model('any-model'))); + } + + public function testOutputStructureIsSupported() + { + $httpClient = new MockHttpClient([ + new JsonMockResponse([ + 'capabilities' => ['completion', 'tools'], + ]), + new JsonMockResponse([ + 'model' => 'foo', + 'response' => [ + 'age' => 22, + 'available' => true, + ], + 'done' => true, + ]), + ], 'http://127.0.0.1:1234'); + + $client = new OllamaClient($httpClient, 'http://127.0.0.1:1234'); + $response = $client->request(new Ollama(), [ + 'messages' => [ + [ + 'role' => 'user', + 'content' => 'Ollama is 22 years old and is busy saving the world. 
Respond using JSON', + ], + ], + 'model' => 'llama3.2', + ], [ + 'response_format' => [ + 'type' => 'json_schema', + 'json_schema' => [ + 'name' => 'clock', + 'strict' => true, + 'schema' => [ + 'type' => 'object', + 'properties' => [ + 'age' => ['type' => 'integer'], + 'available' => ['type' => 'boolean'], + ], + 'required' => ['age', 'available'], + 'additionalProperties' => false, + ], + ], + ], + ]); + + $this->assertSame(2, $httpClient->getRequestsCount()); + $this->assertSame([ + 'model' => 'foo', + 'response' => [ + 'age' => 22, + 'available' => true, + ], + 'done' => true, + ], $response->getData()); + } +} diff --git a/src/platform/tests/Bridge/Ollama/OllamaResultConverterTest.php b/src/platform/tests/Bridge/Ollama/OllamaResultConverterTest.php index a5513fea9..791206200 100644 --- a/src/platform/tests/Bridge/Ollama/OllamaResultConverterTest.php +++ b/src/platform/tests/Bridge/Ollama/OllamaResultConverterTest.php @@ -20,9 +20,13 @@ use Symfony\AI\Platform\Exception\RuntimeException; use Symfony\AI\Platform\Model; use Symfony\AI\Platform\Result\InMemoryRawResult; +use Symfony\AI\Platform\Result\RawHttpResult; use Symfony\AI\Platform\Result\TextResult; use Symfony\AI\Platform\Result\ToolCall; use Symfony\AI\Platform\Result\ToolCallResult; +use Symfony\AI\Platform\Result\VectorResult; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\Contracts\HttpClient\ResponseInterface; #[CoversClass(OllamaResultConverter::class)] #[Small] @@ -30,6 +34,8 @@ #[UsesClass(TextResult::class)] #[UsesClass(ToolCall::class)] #[UsesClass(ToolCallResult::class)] +#[UsesClass(Vector::class)] +#[UsesClass(VectorResult::class)] final class OllamaResultConverterTest extends TestCase { public function testSupportsLlamaModel() @@ -143,4 +149,29 @@ public function testThrowsExceptionWhenNoContent() $converter->convert($rawResult); } + + public function testItConvertsAResponseToAVectorResult() + { + $result = $this->createStub(ResponseInterface::class); + $result + ->method('toArray') + 
->willReturn([ + 'model' => 'all-minilm', + 'embeddings' => [ + [0.3, 0.4, 0.4], + [0.0, 0.0, 0.2], + ], + 'total_duration' => 14143917, + 'load_duration' => 1019500, + 'prompt_eval_count' => 8, + ]); + + $vectorResult = (new OllamaResultConverter())->convert(new RawHttpResult($result)); + $convertedContent = $vectorResult->getContent(); + + $this->assertCount(2, $convertedContent); + + $this->assertSame([0.3, 0.4, 0.4], $convertedContent[0]->getData()); + $this->assertSame([0.0, 0.0, 0.2], $convertedContent[1]->getData()); + } } diff --git a/src/platform/tests/Bridge/Ollama/OllamaTest.php b/src/platform/tests/Bridge/Ollama/OllamaTest.php index 7c3133522..e98d1efd3 100644 --- a/src/platform/tests/Bridge/Ollama/OllamaTest.php +++ b/src/platform/tests/Bridge/Ollama/OllamaTest.php @@ -44,6 +44,17 @@ public function testModelsWithoutToolCallingCapability(string $modelName) ); } + #[DataProvider('provideModelsWithMultipleInputCapabilities')] + public function testModelsWithMultipleInputCapabilities(string $modelName) + { + $model = new Ollama($modelName); + + $this->assertTrue( + $model->supports(Capability::INPUT_MULTIPLE), + \sprintf('Model "%s" should support multiple input capabilities', $modelName) + ); + } + /** * @return iterable */ @@ -82,4 +93,14 @@ public static function provideModelsWithoutToolCallingCapability(): iterable yield 'llava' => [Ollama::LLAVA]; yield 'qwen2.5vl' => [Ollama::QWEN_2_5_VL]; // This has 'vl' suffix which doesn't match the pattern } + + /** + * @return iterable + */ + public static function provideModelsWithMultipleInputCapabilities(): iterable + { + yield 'nomic-embed-text' => [Ollama::NOMIC_EMBED_TEXT]; + yield 'bge-m3' => [Ollama::BGE_M3]; + yield 'all-minilm' => [Ollama::ALL_MINILM]; + } } diff --git a/src/platform/tests/Bridge/OpenAi/DallE/Base64ImageTest.php b/src/platform/tests/Bridge/OpenAi/DallE/Base64ImageTest.php index 04c7dc40f..4b5012924 100644 --- a/src/platform/tests/Bridge/OpenAi/DallE/Base64ImageTest.php +++ 
b/src/platform/tests/Bridge/OpenAi/DallE/Base64ImageTest.php @@ -30,8 +30,8 @@ public function testItCreatesBase64Image() public function testItThrowsExceptionWhenBase64ImageIsEmpty() { - self::expectException(\InvalidArgumentException::class); - self::expectExceptionMessage('The base64 encoded image generated must be given.'); + $this->expectException(\InvalidArgumentException::class); + $this->expectExceptionMessage('The base64 encoded image generated must be given.'); new Base64Image(''); } diff --git a/src/platform/tests/Bridge/OpenAi/DallE/UrlImageTest.php b/src/platform/tests/Bridge/OpenAi/DallE/UrlImageTest.php index 0060cc8f7..99a75a25b 100644 --- a/src/platform/tests/Bridge/OpenAi/DallE/UrlImageTest.php +++ b/src/platform/tests/Bridge/OpenAi/DallE/UrlImageTest.php @@ -29,8 +29,8 @@ public function testItCreatesUrlImage() public function testItThrowsExceptionWhenUrlIsEmpty() { - self::expectException(\InvalidArgumentException::class); - self::expectExceptionMessage('The image url must be given.'); + $this->expectException(\InvalidArgumentException::class); + $this->expectExceptionMessage('The image url must be given.'); new UrlImage(''); } diff --git a/src/platform/tests/Bridge/OpenAi/Gpt/ResultConverterTest.php b/src/platform/tests/Bridge/OpenAi/Gpt/ResultConverterTest.php index 8f0a6f2fe..032769f4f 100644 --- a/src/platform/tests/Bridge/OpenAi/Gpt/ResultConverterTest.php +++ b/src/platform/tests/Bridge/OpenAi/Gpt/ResultConverterTest.php @@ -149,8 +149,8 @@ public function getResponse(): ResponseInterface ]; }); - self::expectException(ContentFilterException::class); - self::expectExceptionMessage('Content was filtered'); + $this->expectException(ContentFilterException::class); + $this->expectExceptionMessage('Content was filtered'); $converter->convert(new RawHttpResult($httpResponse)); } @@ -161,8 +161,8 @@ public function testThrowsExceptionWhenNoChoices() $httpResponse = self::createMock(ResponseInterface::class); 
$httpResponse->method('toArray')->willReturn([]); - self::expectException(RuntimeException::class); - self::expectExceptionMessage('Response does not contain choices'); + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('Response does not contain choices'); $converter->convert(new RawHttpResult($httpResponse)); } @@ -183,8 +183,8 @@ public function testThrowsExceptionForUnsupportedFinishReason() ], ]); - self::expectException(RuntimeException::class); - self::expectExceptionMessage('Unsupported finish reason "unsupported_reason"'); + $this->expectException(RuntimeException::class); + $this->expectExceptionMessage('Unsupported finish reason "unsupported_reason"'); $converter->convert(new RawHttpResult($httpResponse)); } diff --git a/src/platform/tests/Contract/JsonSchema/Attribute/ToolParameterTest.php b/src/platform/tests/Contract/JsonSchema/Attribute/ToolParameterTest.php index 1614f4d6c..e80b9c78d 100644 --- a/src/platform/tests/Contract/JsonSchema/Attribute/ToolParameterTest.php +++ b/src/platform/tests/Contract/JsonSchema/Attribute/ToolParameterTest.php @@ -28,7 +28,7 @@ public function testValidEnum() public function testInvalidEnumContainsNonString() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); $enum = ['value1', 2]; new With(enum: $enum); } @@ -42,7 +42,7 @@ public function testValidConstString() public function testInvalidConstEmptyString() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); $const = ' '; new With(const: $const); } @@ -56,7 +56,7 @@ public function testValidPattern() public function testInvalidPatternEmptyString() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); $pattern = ' '; new With(pattern: $pattern); } @@ -70,7 +70,7 @@ public function testValidMinLength() public function 
testInvalidMinLengthNegative() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(minLength: -1); } @@ -85,7 +85,7 @@ public function testValidMinLengthAndMaxLength() public function testInvalidMaxLengthLessThanMinLength() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(minLength: 10, maxLength: 5); } @@ -98,7 +98,7 @@ public function testValidMinimum() public function testInvalidMinimumNegative() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(minimum: -1); } @@ -111,7 +111,7 @@ public function testValidMultipleOf() public function testInvalidMultipleOfNegative() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(multipleOf: -5); } @@ -126,7 +126,7 @@ public function testValidExclusiveMinimumAndMaximum() public function testInvalidExclusiveMaximumLessThanExclusiveMinimum() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(exclusiveMinimum: 10, exclusiveMaximum: 5); } @@ -141,7 +141,7 @@ public function testValidMinItemsAndMaxItems() public function testInvalidMaxItemsLessThanMinItems() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(minItems: 5, maxItems: 1); } @@ -153,7 +153,7 @@ public function testValidUniqueItemsTrue() public function testInvalidUniqueItemsFalse() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(uniqueItems: false); } @@ -168,7 +168,7 @@ public function testValidMinContainsAndMaxContains() public function testInvalidMaxContainsLessThanMinContains() { - 
self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(minContains: 3, maxContains: 1); } @@ -189,7 +189,7 @@ public function testValidMinPropertiesAndMaxProperties() public function testInvalidMaxPropertiesLessThanMinProperties() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(minProperties: 5, maxProperties: 1); } @@ -228,7 +228,7 @@ enum: ['value1', 'value2'], public function testInvalidCombination() { - self::expectException(InvalidArgumentException::class); + $this->expectException(InvalidArgumentException::class); new With(minLength: -1, maxLength: -2); } } diff --git a/src/platform/tests/InMemoryPlatformTest.php b/src/platform/tests/InMemoryPlatformTest.php index 4cda27495..2ede09f2d 100644 --- a/src/platform/tests/InMemoryPlatformTest.php +++ b/src/platform/tests/InMemoryPlatformTest.php @@ -13,6 +13,8 @@ use PHPUnit\Framework\TestCase; use Symfony\AI\Platform\InMemoryPlatform; use Symfony\AI\Platform\Model; +use Symfony\AI\Platform\Result\VectorResult; +use Symfony\AI\Platform\Vector\Vector; #[CoversClass(InMemoryPlatform::class)] class InMemoryPlatformTest extends TestCase @@ -37,4 +39,15 @@ public function testPlatformInvokeWithCallableResult() $this->assertSame('DYNAMIC TEXT', $result->asText()); } + + public function testPlatformInvokeWithVectorResultResponse() + { + $platform = new InMemoryPlatform( + fn () => new VectorResult(new Vector([0.1, 0.1, 0.5])) + ); + + $result = $platform->invoke(new Model('test'), 'dynamic text'); + + $this->assertEquals([0.1, 0.1, 0.5], $result->asVectors()[0]->getData()); + } } diff --git a/src/platform/tests/Message/Content/AudioTest.php b/src/platform/tests/Message/Content/AudioTest.php index 99d34c8a1..9b04e6997 100644 --- a/src/platform/tests/Message/Content/AudioTest.php +++ b/src/platform/tests/Message/Content/AudioTest.php @@ -39,8 +39,8 @@ public function 
testFromDataUrlWithValidUrl() public function testFromDataUrlWithInvalidUrl() { - self::expectException(\InvalidArgumentException::class); - self::expectExceptionMessage('Invalid audio data URL format.'); + $this->expectException(\InvalidArgumentException::class); + $this->expectExceptionMessage('Invalid audio data URL format.'); Audio::fromDataUrl('invalid-url'); } @@ -55,8 +55,8 @@ public function testFromFileWithValidPath() public function testFromFileWithInvalidPath() { - self::expectException(\InvalidArgumentException::class); - self::expectExceptionMessage('The file "foo.mp3" does not exist or is not readable.'); + $this->expectException(\InvalidArgumentException::class); + $this->expectExceptionMessage('The file "foo.mp3" does not exist or is not readable.'); Audio::fromFile('foo.mp3'); } diff --git a/src/platform/tests/Message/Content/BinaryTest.php b/src/platform/tests/Message/Content/BinaryTest.php index 3521563bc..8ad3ec014 100644 --- a/src/platform/tests/Message/Content/BinaryTest.php +++ b/src/platform/tests/Message/Content/BinaryTest.php @@ -35,8 +35,8 @@ public function testCreateFromDataUrl() public function testThrowsExceptionForInvalidDataUrl() { - self::expectException(InvalidArgumentException::class); - self::expectExceptionMessage('Invalid audio data URL format.'); + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Invalid audio data URL format.'); File::fromDataUrl('invalid-data-url'); } @@ -77,7 +77,7 @@ public static function provideExistingFiles(): iterable public function testThrowsExceptionForNonExistentFile() { - self::expectException(\InvalidArgumentException::class); + $this->expectException(\InvalidArgumentException::class); File::fromFile('/non/existent/file.jpg'); } diff --git a/src/platform/tests/Message/Content/ImageTest.php b/src/platform/tests/Message/Content/ImageTest.php index 6fee9cdeb..b8a14a3e7 100644 --- a/src/platform/tests/Message/Content/ImageTest.php +++ 
b/src/platform/tests/Message/Content/ImageTest.php @@ -34,7 +34,7 @@ public function testWithValidFile() public function testFromBinaryWithInvalidFile() { - self::expectExceptionMessage('The file "foo.jpg" does not exist or is not readable.'); + $this->expectExceptionMessage('The file "foo.jpg" does not exist or is not readable.'); Image::fromFile('foo.jpg'); } diff --git a/src/platform/tests/Result/BaseResultTest.php b/src/platform/tests/Result/BaseResultTest.php index a216795c8..c7493bb6a 100644 --- a/src/platform/tests/Result/BaseResultTest.php +++ b/src/platform/tests/Result/BaseResultTest.php @@ -55,7 +55,7 @@ public function testItCanBeEnrichedWithARawResponse() public function testItThrowsAnExceptionWhenSettingARawResponseTwice() { - self::expectException(RawResultAlreadySetException::class); + $this->expectException(RawResultAlreadySetException::class); $result = $this->createResult(); $rawResult = $this->createRawResult(); diff --git a/src/platform/tests/Result/ChoiceResultTest.php b/src/platform/tests/Result/ChoiceResultTest.php index c683468f8..b50097e89 100644 --- a/src/platform/tests/Result/ChoiceResultTest.php +++ b/src/platform/tests/Result/ChoiceResultTest.php @@ -35,8 +35,8 @@ public function testChoiceResultCreation() public function testChoiceResultWithNoChoices() { - self::expectException(InvalidArgumentException::class); - self::expectExceptionMessage('A choice result must contain at least two results.'); + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('A choice result must contain at least two results.'); new ChoiceResult(); } diff --git a/src/platform/tests/Result/RawResultAwareTraitTest.php b/src/platform/tests/Result/RawResultAwareTraitTest.php index b79596ac1..677a74b43 100644 --- a/src/platform/tests/Result/RawResultAwareTraitTest.php +++ b/src/platform/tests/Result/RawResultAwareTraitTest.php @@ -36,7 +36,7 @@ public function testItCanBeEnrichedWithARawResponse() public function 
testItThrowsAnExceptionWhenSettingARawResponseTwice() { - self::expectException(RawResultAlreadySetException::class); + $this->expectException(RawResultAlreadySetException::class); $result = $this->createTestClass(); $rawResponse = self::createMock(SymfonyHttpResponse::class); diff --git a/src/platform/tests/Result/TollCallResultTest.php b/src/platform/tests/Result/TollCallResultTest.php index 9bd913558..90bedd90e 100644 --- a/src/platform/tests/Result/TollCallResultTest.php +++ b/src/platform/tests/Result/TollCallResultTest.php @@ -26,8 +26,8 @@ final class TollCallResultTest extends TestCase { public function testThrowsIfNoToolCall() { - self::expectException(InvalidArgumentException::class); - self::expectExceptionMessage('Response must have at least one tool call.'); + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Response must have at least one tool call.'); new ToolCallResult(); } diff --git a/src/platform/tests/Vector/NullVectorTest.php b/src/platform/tests/Vector/NullVectorTest.php index 77c695f9c..79c1dca33 100644 --- a/src/platform/tests/Vector/NullVectorTest.php +++ b/src/platform/tests/Vector/NullVectorTest.php @@ -27,14 +27,14 @@ public function testImplementsInterface() public function testGetDataThrowsOnAccess() { - self::expectException(RuntimeException::class); + $this->expectException(RuntimeException::class); (new NullVector())->getData(); } public function testGetDimensionsThrowsOnAccess() { - self::expectException(RuntimeException::class); + $this->expectException(RuntimeException::class); (new NullVector())->getDimensions(); } diff --git a/src/store/CHANGELOG.md b/src/store/CHANGELOG.md index 0fa347f23..e395af555 100644 --- a/src/store/CHANGELOG.md +++ b/src/store/CHANGELOG.md @@ -28,23 +28,25 @@ CHANGELOG - Orchestrates document processing pipeline - Accepts TextDocuments, vectorizes and stores in chunks - Configurable batch processing - * Add `InMemoryStore` implementation with multiple distance 
algorithms: + * Add `InMemoryStore` and `CacheStore` implementations with multiple distance algorithms: - Cosine similarity - Angular distance - Euclidean distance - Manhattan distance - Chebyshev distance * Add store bridge implementations: - - PostgreSQL with pgvector extension - - MariaDB - - MongoDB - Azure AI Search - - Meilisearch - ChromaDB + - ClickHouse + - MariaDB + - Meilisearch + - MongoDB + - Neo4j - Pinecone + - PostgreSQL with pgvector extension - Qdrant - SurrealDB - - Neo4j + - Typesense * Add Retrieval Augmented Generation (RAG) support: - Document embedding storage - Similarity search for relevant documents diff --git a/src/store/composer.json b/src/store/composer.json index 9b89bb867..5ff1e5f2c 100644 --- a/src/store/composer.json +++ b/src/store/composer.json @@ -40,7 +40,8 @@ "mongodb/mongodb": "^1.21 || ^2.0", "phpstan/phpstan": "^2.0", "phpunit/phpunit": "^11.5", - "probots-io/pinecone-php": "^1.0" + "probots-io/pinecone-php": "^1.0", + "symfony/cache": "^7.3" }, "config": { "sort-packages": true diff --git a/src/store/doc/index.rst b/src/store/doc/index.rst index 91fd26bca..7ca41cf2e 100644 --- a/src/store/doc/index.rst +++ b/src/store/doc/index.rst @@ -41,10 +41,18 @@ You can find more advanced usage in combination with an Agent using the store fo * `Similarity Search with Meilisearch (RAG)`_ * `Similarity Search with memory storage (RAG)`_ * `Similarity Search with MongoDB (RAG)`_ +* `Similarity Search with Neo4j (RAG)`_ * `Similarity Search with Pinecone (RAG)`_ +* `Similarity Search with PSR-6 Cache (RAG)`_ * `Similarity Search with Qdrant (RAG)`_ * `Similarity Search with SurrealDB (RAG)`_ -* `Similarity Search with Neo4j (RAG)`_ +* `Similarity Search with Typesense (RAG)`_ + +.. note:: + + Both `InMemory` and `PSR-6 cache` vector stores will load all the data into the + memory of the PHP process. They can be used only if the amount of data fits in the + PHP memory limit, typically for testing.
Supported Stores ---------------- @@ -55,11 +63,13 @@ Supported Stores * `MariaDB`_ (requires `ext-pdo`) * `Meilisearch`_ * `MongoDB Atlas`_ (requires `mongodb/mongodb` as additional dependency) +* `Neo4j`_ * `Pinecone`_ (requires `probots-io/pinecone-php` as additional dependency) * `Postgres`_ (requires `ext-pdo`) +* `PSR-6 Cache`_ * `Qdrant`_ * `SurrealDB`_ -* `Neo4j`_ +* `Typesense`_ .. note:: @@ -93,14 +103,16 @@ This leads to a store implementing two methods:: } .. _`Retrieval Augmented Generation`: https://de.wikipedia.org/wiki/Retrieval-Augmented_Generation -.. _`Similarity Search with MariaDB (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/mariadb.php +.. _`Similarity Search with MariaDB (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/mariadb-gemini.php .. _`Similarity Search with MongoDB (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/mongodb.php -.. _`Similarity Search with Pinecone (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/pinecone.php .. _`Similarity Search with Meilisearch (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/meilisearch.php -.. _`Similarity Search with SurrealDB (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/surrealdb.php .. _`Similarity Search with memory storage (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/in-memory.php -.. _`Similarity Search with Qdrant (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/qdrant.php .. _`Similarity Search with Neo4j (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/neo4j.php +.. _`Similarity Search with Pinecone (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/pinecone.php +.. _`Similarity Search with PSR-6 Cache (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/cache.php +.. _`Similarity Search with Qdrant (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/qdrant.php +.. 
_`Similarity Search with SurrealDB (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/surrealdb.php +.. _`Similarity Search with Typesense (RAG)`: https://github.com/symfony/ai/blob/main/examples/rag/typesense.php .. _`Azure AI Search`: https://azure.microsoft.com/products/ai-services/ai-search .. _`Chroma`: https://www.trychroma.com/ .. _`MariaDB`: https://mariadb.org/projects/mariadb-vector/ @@ -112,4 +124,6 @@ This leads to a store implementing two methods:: .. _`InMemory`: https://www.php.net/manual/en/language.types.array.php .. _`Qdrant`: https://qdrant.tech/ .. _`Neo4j`: https://neo4j.com/ +.. _`Typesense`: https://typesense.org/ .. _`GitHub`: https://github.com/symfony/ai/issues/16 +.. _`PSR-6 Cache`: https://www.php-fig.org/psr/psr-6/ diff --git a/src/store/src/Bridge/ClickHouse/Store.php b/src/store/src/Bridge/ClickHouse/Store.php new file mode 100644 index 000000000..e600985b2 --- /dev/null +++ b/src/store/src/Bridge/ClickHouse/Store.php @@ -0,0 +1,179 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
+ */ + +namespace Symfony\AI\Store\Bridge\ClickHouse; + +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Platform\Vector\VectorInterface; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Exception\RuntimeException; +use Symfony\AI\Store\InitializableStoreInterface; +use Symfony\AI\Store\VectorStoreInterface; +use Symfony\Component\Uid\Uuid; +use Symfony\Contracts\HttpClient\HttpClientInterface; +use Symfony\Contracts\HttpClient\ResponseInterface; + +/** + * @author Grégoire Pineau + */ +class Store implements VectorStoreInterface, InitializableStoreInterface +{ + public function __construct( + private readonly HttpClientInterface $httpClient, + private readonly string $databaseName = 'default', + private readonly string $tableName = 'embedding', + ) { + } + + public function initialize(array $options = []): void + { + $sql = <<<'SQL' + CREATE TABLE IF NOT EXISTS {{ table }} ( + id UUID, + metadata String, + embedding Array(Float32), + ) ENGINE = MergeTree() + ORDER BY id + SQL; + + $this->execute('POST', $sql); + } + + public function add(VectorDocument ...$documents): void + { + $rows = []; + + foreach ($documents as $document) { + $rows[] = $this->formatVectorDocument($document); + } + + $this->insertBatch($rows); + } + + /** + * @return array + */ + protected function formatVectorDocument(VectorDocument $document): array + { + return [ + 'id' => $document->id->toRfc4122(), + 'metadata' => json_encode($document->metadata->getArrayCopy(), \JSON_THROW_ON_ERROR), + 'embedding' => $document->vector->getData(), + ]; + } + + public function query(Vector $vector, array $options = [], ?float $minScore = null): array + { + $sql = <<<'SQL' + SELECT + id, + embedding, + metadata, + cosineDistance(embedding, {query_vector:Array(Float32)}) as score + FROM {{ table }} + WHERE length(embedding) = length({query_vector:Array(Float32)}) {{ where }} + ORDER BY score ASC + LIMIT {limit:UInt32} + SQL; + + if 
(isset($options['where'])) { + $sql = str_replace('{{ where }}', 'AND '.$options['where'], $sql); + } else { + $sql = str_replace('{{ where }}', '', $sql); + } + + $results = $this + ->execute('GET', $sql, [ + 'query_vector' => $this->toClickHouseVector($vector), + 'limit' => $options['limit'] ?? 5, + ...$options['params'] ?? [], + ]) + ->toArray()['data'] + ; + + $documents = []; + foreach ($results as $result) { + $documents[] = new VectorDocument( + id: Uuid::fromString($result['id']), + vector: new Vector($result['embedding']), + metadata: new Metadata(json_decode($result['metadata'] ?? '{}', true, 512, \JSON_THROW_ON_ERROR)), + score: $result['score'], + ); + } + + return $documents; + } + + /** + * @param array $params + */ + protected function execute(string $method, string $sql, array $params = []): ResponseInterface + { + $sql = str_replace('{{ table }}', $this->tableName, $sql); + + $options = [ + 'query' => [ + 'query' => $sql, + 'database' => $this->databaseName, + 'default_format' => 'JSON', + ], + ]; + + foreach ($params as $key => $value) { + $options['query']['param_'.$key] = $value; + } + + return $this->httpClient->request($method, '/', $options); + } + + /** + * @param array> $rows + */ + private function insertBatch(array $rows): void + { + if (!$rows) { + return; + } + + $sql = 'INSERT INTO {{ table }} FORMAT JSONEachRow'; + $sql = str_replace('{{ table }}', $this->tableName, $sql); + + $jsonData = ''; + foreach ($rows as $row) { + $jsonData .= json_encode($row, \JSON_THROW_ON_ERROR)."\n"; + } + + $options = [ + 'query' => [ + 'query' => $sql, + 'database' => $this->databaseName, + ], + 'body' => $jsonData, + 'headers' => [ + 'Content-Type' => 'application/json', + ], + ]; + + $response = $this->httpClient->request('POST', '/', $options); + + if (200 !== $response->getStatusCode()) { + $content = $response->getContent(false); + + throw new RuntimeException("Could not insert data into ClickHouse. Http status code: {$response->getStatusCode()}.
Response: {$content}."); + } + } + + private function toClickHouseVector(VectorInterface $vector): string + { + return '['.implode(',', $vector->getData()).']'; + } +} diff --git a/src/store/src/Bridge/Typesense/Store.php b/src/store/src/Bridge/Typesense/Store.php new file mode 100644 index 000000000..b2c49ecd4 --- /dev/null +++ b/src/store/src/Bridge/Typesense/Store.php @@ -0,0 +1,133 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store\Bridge\Typesense; + +use Symfony\AI\Platform\Vector\NullVector; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Exception\InvalidArgumentException; +use Symfony\AI\Store\InitializableStoreInterface; +use Symfony\AI\Store\VectorStoreInterface; +use Symfony\Component\Uid\Uuid; +use Symfony\Contracts\HttpClient\HttpClientInterface; + +/** + * @author Guillaume Loulier + */ +final readonly class Store implements InitializableStoreInterface, VectorStoreInterface +{ + public function __construct( + private HttpClientInterface $httpClient, + private string $endpointUrl, + #[\SensitiveParameter] private string $apiKey, + #[\SensitiveParameter] private string $collection, + private string $vectorFieldName = '_vectors', + private int $embeddingsDimension = 1536, + ) { + } + + public function add(VectorDocument ...$documents): void + { + foreach ($documents as $document) { + $this->request('POST', \sprintf('collections/%s/documents', $this->collection), $this->convertToIndexableArray($document)); + } + } + + public function query(Vector $vector, array $options = []): array + { + $documents = $this->request('POST', 'multi_search', [ + 'searches' => [ + [ + 'collection' => $this->collection, + 'q' => '*', + 'vector_query' => \sprintf('%s:([%s], k:%d)', $this->vectorFieldName, implode(', ', $vector->getData()), $options['k'] ?? 
10), + ], + ], + ]); + + return array_map($this->convertToVectorDocument(...), $documents['results'][0]['hits']); + } + + public function initialize(array $options = []): void + { + if ([] !== $options) { + throw new InvalidArgumentException('No supported options.'); + } + + $this->request('POST', 'collections', [ + 'name' => $this->collection, + 'fields' => [ + [ + 'name' => 'id', + 'type' => 'string', + ], + [ + 'name' => $this->vectorFieldName, + 'type' => 'float[]', + 'num_dim' => $this->embeddingsDimension, + ], + [ + 'name' => 'metadata', + 'type' => 'string', + ], + ], + ]); + } + + /** + * @param array $payload + * + * @return array + */ + private function request(string $method, string $endpoint, array $payload): array + { + $url = \sprintf('%s/%s', $this->endpointUrl, $endpoint); + $result = $this->httpClient->request($method, $url, [ + 'headers' => [ + 'X-TYPESENSE-API-KEY' => $this->apiKey, + ], + 'json' => $payload, + ]); + + return $result->toArray(); + } + + /** + * @return array + */ + private function convertToIndexableArray(VectorDocument $document): array + { + return [ + 'id' => $document->id->toRfc4122(), + $this->vectorFieldName => $document->vector->getData(), + 'metadata' => json_encode($document->metadata->getArrayCopy()), + ]; + } + + /** + * @param array $data + */ + private function convertToVectorDocument(array $data): VectorDocument + { + $document = $data['document'] ?? throw new InvalidArgumentException('Missing "document" field in the document data.'); + + $id = $document['id'] ?? throw new InvalidArgumentException('Missing "id" field in the document data.'); + + $vector = !\array_key_exists($this->vectorFieldName, $document) || null === $document[$this->vectorFieldName] + ? new NullVector() : new Vector($document[$this->vectorFieldName]); + + $score = $data['vector_distance'] ?? 
null; + + return new VectorDocument(Uuid::fromString($id), $vector, new Metadata(json_decode($document['metadata'], true)), $score); + } +} diff --git a/src/store/src/CacheStore.php b/src/store/src/CacheStore.php new file mode 100644 index 000000000..1223eda2d --- /dev/null +++ b/src/store/src/CacheStore.php @@ -0,0 +1,74 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. + */ + +namespace Symfony\AI\Store; + +use Psr\Cache\CacheItemPoolInterface; +use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\Document\Metadata; +use Symfony\AI\Store\Document\VectorDocument; +use Symfony\AI\Store\Exception\RuntimeException; +use Symfony\Component\Uid\Uuid; +use Symfony\Contracts\Cache\CacheInterface; + +/** + * @author Guillaume Loulier + */ +final readonly class CacheStore implements VectorStoreInterface +{ + public function __construct( + private CacheInterface&CacheItemPoolInterface $cache, + private DistanceCalculator $distanceCalculator = new DistanceCalculator(), + private string $cacheKey = '_vectors', + ) { + if (!interface_exists(CacheItemPoolInterface::class)) { + throw new RuntimeException('For using the CacheStore as vector store, a PSR-6 cache implementation is required. 
Try running "composer require symfony/cache" or another PSR-6 compatible cache.'); + } + } + + public function add(VectorDocument ...$documents): void + { + $existingVectors = $this->cache->get($this->cacheKey, static fn (): array => []); + + $newVectors = array_map(static fn (VectorDocument $document): array => [ + 'id' => $document->id->toRfc4122(), + 'vector' => $document->vector->getData(), + 'metadata' => $document->metadata->getArrayCopy(), + ], $documents); + + $cacheItem = $this->cache->getItem($this->cacheKey); + + $cacheItem->set([ + ...$existingVectors, + ...$newVectors, + ]); + + $this->cache->save($cacheItem); + } + + /** + * @param array{ + * maxItems?: positive-int + * } $options If maxItems is provided, only the top N results will be returned + */ + public function query(Vector $vector, array $options = []): array + { + $documents = $this->cache->getItem($this->cacheKey)->get() ?? []; + + $vectorDocuments = array_map(static fn (array $document): VectorDocument => new VectorDocument( + id: Uuid::fromString($document['id']), + vector: new Vector($document['vector']), + metadata: new Metadata($document['metadata']), + ), $documents); + + return $this->distanceCalculator->calculate($vectorDocuments, $vector, $options['maxItems'] ?? null); + } +} diff --git a/src/store/src/DistanceCalculator.php b/src/store/src/DistanceCalculator.php new file mode 100644 index 000000000..a60b8a813 --- /dev/null +++ b/src/store/src/DistanceCalculator.php @@ -0,0 +1,133 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
/**
 * Ranks vector documents by their distance to a query vector.
 *
 * The metric is selected once, at construction time, through a DistanceStrategy case.
 * Every metric returns a value where a smaller number means "closer", so results are
 * always sorted in ascending order of distance.
 *
 * @author Guillaume Loulier
 */
final readonly class DistanceCalculator
{
    public function __construct(
        private DistanceStrategy $strategy = DistanceStrategy::COSINE_DISTANCE,
    ) {
    }

    /**
     * Sorts the given documents from the closest to the farthest from $vector.
     *
     * @param VectorDocument[] $documents
     * @param ?int             $maxItems  If maxItems is provided, only the top N results will be returned
     *
     * @return VectorDocument[]
     */
    public function calculate(array $documents, Vector $vector, ?int $maxItems = null): array
    {
        $strategy = match ($this->strategy) {
            DistanceStrategy::COSINE_DISTANCE => $this->cosineDistance(...),
            DistanceStrategy::ANGULAR_DISTANCE => $this->angularDistance(...),
            DistanceStrategy::EUCLIDEAN_DISTANCE => $this->euclideanDistance(...),
            DistanceStrategy::MANHATTAN_DISTANCE => $this->manhattanDistance(...),
            DistanceStrategy::CHEBYSHEV_DISTANCE => $this->chebyshevDistance(...),
        };

        // Pair every document with its distance so the distance is computed once per
        // document instead of once per comparison during the sort.
        $currentEmbeddings = array_map(
            static fn (VectorDocument $vectorDocument): array => [
                'distance' => $strategy($vectorDocument, $vector),
                'document' => $vectorDocument,
            ],
            $documents,
        );

        usort(
            $currentEmbeddings,
            static fn (array $embedding, array $nextEmbedding): int => $embedding['distance'] <=> $nextEmbedding['distance'],
        );

        if (null !== $maxItems && $maxItems < \count($currentEmbeddings)) {
            $currentEmbeddings = \array_slice($currentEmbeddings, 0, $maxItems);
        }

        return array_map(
            static fn (array $embedding): VectorDocument => $embedding['document'],
            $currentEmbeddings,
        );
    }

    /**
     * Cosine distance: 1 - cosine similarity, hence a value in [0, 2].
     */
    private function cosineDistance(VectorDocument $embedding, Vector $against): float
    {
        return 1 - $this->cosineSimilarity($embedding, $against);
    }

    /**
     * Cosine of the angle between both vectors, always within [-1, 1].
     *
     * When either vector has a zero magnitude the angle is undefined; 0.0 is returned
     * (i.e. the vectors are treated as orthogonal) instead of NAN, because a NAN distance
     * makes the comparison used by usort() inconsistent and the ranking arbitrary.
     */
    private function cosineSimilarity(VectorDocument $embedding, Vector $against): float
    {
        $currentEmbeddingVectors = $embedding->vector->getData();
        $againstVectors = $against->getData();

        $dotProduct = array_sum(array_map(
            static fn (float $a, float $b): float => $a * $b,
            $currentEmbeddingVectors,
            $againstVectors,
        ));

        $magnitudes = $this->magnitude($currentEmbeddingVectors) * $this->magnitude($againstVectors);

        if (0.0 === $magnitudes) {
            return 0.0;
        }

        // Rounding errors can push the ratio slightly outside [-1, 1] (e.g. for two
        // identical vectors); clamp it so acos() in angularDistance() never yields NAN.
        return max(-1.0, min(1.0, fdiv($dotProduct, $magnitudes)));
    }

    /**
     * Angular distance: the angle between both vectors normalised to [0, 1].
     */
    private function angularDistance(VectorDocument $embedding, Vector $against): float
    {
        $cosineSimilarity = $this->cosineSimilarity($embedding, $against);

        return fdiv(acos($cosineSimilarity), \M_PI);
    }

    /**
     * Euclidean (L2) distance: square root of the summed squared differences.
     */
    private function euclideanDistance(VectorDocument $embedding, Vector $against): float
    {
        return sqrt(array_sum(array_map(
            static fn (float $a, float $b): float => ($a - $b) ** 2,
            $embedding->vector->getData(),
            $against->getData(),
        )));
    }

    /**
     * Manhattan (L1) distance: sum of the absolute differences.
     */
    private function manhattanDistance(VectorDocument $embedding, Vector $against): float
    {
        return array_sum(array_map(
            static fn (float $a, float $b): float => abs($a - $b),
            $embedding->vector->getData(),
            $against->getData(),
        ));
    }

    /**
     * Chebyshev (L-infinity) distance: the largest absolute difference on any axis.
     */
    private function chebyshevDistance(VectorDocument $embedding, Vector $against): float
    {
        $absoluteDifferences = array_map(
            static fn (float $currentValue, float $againstValue): float => abs($currentValue - $againstValue),
            $embedding->vector->getData(),
            $against->getData(),
        );

        return array_reduce(
            array: $absoluteDifferences,
            callback: static fn (float $value, float $current): float => max($value, $current),
            initial: 0.0,
        );
    }

    /**
     * Euclidean norm (length) of a list of components.
     *
     * @param float[] $values
     */
    private function magnitude(array $values): float
    {
        return sqrt(array_sum(array_map(
            static fn (float $value): float => $value ** 2,
            $values,
        )));
    }
}
/**
 * Distance metrics that can be selected when comparing vectors.
 *
 * Each case is backed by the lowercase name of its metric.
 *
 * @author Guillaume Loulier
 */
enum DistanceStrategy: string
{
    /** Backed by the string "cosine". */
    case COSINE_DISTANCE = 'cosine';

    /** Backed by the string "angular". */
    case ANGULAR_DISTANCE = 'angular';

    /** Backed by the string "euclidean". */
    case EUCLIDEAN_DISTANCE = 'euclidean';

    /** Backed by the string "manhattan". */
    case MANHATTAN_DISTANCE = 'manhattan';

    /** Backed by the string "chebyshev". */
    case CHEBYSHEV_DISTANCE = 'chebyshev';
}
$vectorDocument, - ], - $this->documents, - ); - - usort( - $currentEmbeddings, - static fn (array $embedding, array $nextEmbedding): int => $embedding['distance'] <=> $nextEmbedding['distance'], - ); - - if (\array_key_exists('maxItems', $options) && $options['maxItems'] < \count($currentEmbeddings)) { - $currentEmbeddings = \array_slice($currentEmbeddings, 0, $options['maxItems']); - } - - return array_map( - static fn (array $embedding): VectorDocument => $embedding['document'], - $currentEmbeddings, - ); - } - - private function cosineDistance(VectorDocument $embedding, Vector $against): float - { - return 1 - $this->cosineSimilarity($embedding, $against); - } - - private function cosineSimilarity(VectorDocument $embedding, Vector $against): float - { - $currentEmbeddingVectors = $embedding->vector->getData(); - - $dotProduct = array_sum(array: array_map( - static fn (float $a, float $b): float => $a * $b, - $currentEmbeddingVectors, - $against->getData(), - )); - - $currentEmbeddingLength = sqrt(array_sum(array_map( - static fn (float $value): float => $value ** 2, - $currentEmbeddingVectors, - ))); - - $againstLength = sqrt(array_sum(array_map( - static fn (float $value): float => $value ** 2, - $against->getData(), - ))); - - return fdiv($dotProduct, $currentEmbeddingLength * $againstLength); - } - - private function angularDistance(VectorDocument $embedding, Vector $against): float - { - $cosineSimilarity = $this->cosineSimilarity($embedding, $against); - - return fdiv(acos($cosineSimilarity), \M_PI); - } - - private function euclideanDistance(VectorDocument $embedding, Vector $against): float - { - return sqrt(array_sum(array_map( - static fn (float $a, float $b): float => ($a - $b) ** 2, - $embedding->vector->getData(), - $against->getData(), - ))); - } - - private function manhattanDistance(VectorDocument $embedding, Vector $against): float - { - return array_sum(array_map( - static fn (float $a, float $b): float => abs($a - $b), - 
$embedding->vector->getData(), - $against->getData(), - )); - } - - private function chebyshevDistance(VectorDocument $embedding, Vector $against): float - { - $embeddingsAsPower = array_map( - static fn (float $currentValue, float $againstValue): float => abs($currentValue - $againstValue), - $embedding->vector->getData(), - $against->getData(), - ); - - return array_reduce( - array: $embeddingsAsPower, - callback: static fn (float $value, float $current): float => max($value, $current), - initial: 0.0, - ); + return $this->distanceCalculator->calculate($this->documents, $vector, $options['maxItems'] ?? null); } } diff --git a/src/store/tests/Bridge/ClickHouse/StoreTest.php b/src/store/tests/Bridge/ClickHouse/StoreTest.php new file mode 100644 index 000000000..d5c8ce422 --- /dev/null +++ b/src/store/tests/Bridge/ClickHouse/StoreTest.php @@ -0,0 +1,236 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
/**
 * Tests the ClickHouse store against a mocked HTTP client.
 *
 * Each test replaces the HTTP layer with a MockHttpClient whose response factory
 * asserts on the outgoing request, so no ClickHouse server is needed.
 */
#[CoversClass(Store::class)]
final class StoreTest extends TestCase
{
    public function testInitialize()
    {
        // Requests actually received by the mock client, recorded to assert that
        // initialize() issues exactly one call.
        $requests = [];

        $httpClient = new MockHttpClient(function (string $method, string $url, array $options) use (&$requests) {
            $requests[] = compact('method', 'url', 'options');

            $expectedSql = 'CREATE TABLE IF NOT EXISTS test_table (
                id UUID,
                metadata String,
                embedding Array(Float32),
            ) ENGINE = MergeTree()
            ORDER BY id';

            $this->assertSame('POST', $method);
            $this->assertStringContainsString('?', $url); // Check that URL has query parameters
            // Whitespace is stripped on both sides so the comparison ignores SQL formatting.
            $this->assertEquals(
                str_replace([' ', "\n", "\t"], '', $expectedSql),
                str_replace([' ', "\n", "\t"], '', $options['query']['query'])
            );

            return new MockResponse('');
        });

        $store = new Store($httpClient, 'test_db', 'test_table');

        $store->initialize();

        $this->assertCount(1, $requests);
    }

    public function testAddSingleDocument()
    {
        $uuid = Uuid::v4();
        $document = new VectorDocument($uuid, new Vector([0.1, 0.2, 0.3]), new Metadata(['title' => 'Test Document']));

        // One JSON object per line, as required by the JSONEachRow format asserted below.
        $expectedJsonData = json_encode([
            'id' => $uuid->toRfc4122(),
            'metadata' => json_encode(['title' => 'Test Document']),
            'embedding' => [0.1, 0.2, 0.3],
        ])."\n";

        $httpClient = new MockHttpClient(function (string $method, string $url, array $options) use ($expectedJsonData) {
            $this->assertSame('POST', $method);
            $this->assertStringContainsString('?', $url); // Check that URL has query parameters
            $this->assertSame($expectedJsonData, $options['body']);
            $this->assertSame('INSERT INTO test_table FORMAT JSONEachRow', $options['query']['query']);
            $this->assertSame('Content-Type: application/json', $options['headers'][0]);

            return new MockResponse('', ['http_code' => 200]);
        });

        $store = new Store($httpClient, 'test_db', 'test_table');

        $store->add($document);
    }

    public function testAddMultipleDocuments()
    {
        $uuid1 = Uuid::v4();
        $uuid2 = Uuid::v4();
        $document1 = new VectorDocument($uuid1, new Vector([0.1, 0.2, 0.3]));
        $document2 = new VectorDocument($uuid2, new Vector([0.4, 0.5, 0.6]), new Metadata(['title' => 'Second']));

        $expectedJsonData = json_encode([
            'id' => $uuid1->toRfc4122(),
            'metadata' => json_encode([]),
            'embedding' => [0.1, 0.2, 0.3],
        ])."\n".json_encode([
            'id' => $uuid2->toRfc4122(),
            'metadata' => json_encode(['title' => 'Second']),
            'embedding' => [0.4, 0.5, 0.6],
        ])."\n";

        $httpClient = new MockHttpClient(function (string $method, string $url, array $options) use ($expectedJsonData) {
            $this->assertSame('POST', $method);
            $this->assertStringContainsString('?', $url); // Check that URL has query parameters
            $this->assertSame($expectedJsonData, $options['body']);

            return new MockResponse('', ['http_code' => 200]);
        });

        $store = new Store($httpClient, 'test_db', 'test_table');

        $store->add($document1, $document2);
    }

    public function testAddThrowsExceptionOnHttpError()
    {
        $document = new VectorDocument(Uuid::v4(), new Vector([0.1, 0.2, 0.3]));

        // The request itself is irrelevant here: every call fails with a 500.
        $httpClient = new MockHttpClient(
            static fn (): MockResponse => new MockResponse('Internal Server Error', ['http_code' => 500]),
        );

        $store = new Store($httpClient, 'test_db', 'test_table');

        $this->expectException(RuntimeException::class);
        $this->expectExceptionMessage('Could not insert data into ClickHouse. Http status code: 500. Response: Internal Server Error.');

        $store->add($document);
    }

    public function testQuery()
    {
        $queryVector = new Vector([0.1, 0.2, 0.3]);

        $httpClient = new MockHttpClient(function (string $method, string $url, array $options) {
            $this->assertSame('GET', $method);
            $this->assertStringContainsString('?', $url); // Check that URL has query parameters
            $this->assertSame('[0.1,0.2,0.3]', $options['query']['param_query_vector']);
            // No "limit" option is passed to query(), so this is the limit the store applies on its own.
            $this->assertSame(5, $options['query']['param_limit']);

            return new MockResponse(json_encode([
                'data' => [
                    [
                        'id' => '01234567-89ab-cdef-0123-456789abcdef',
                        'embedding' => [0.1, 0.2, 0.3],
                        'metadata' => json_encode(['title' => 'Test Document']),
                        'score' => 0.95,
                    ],
                ],
            ]));
        });

        $store = new Store($httpClient, 'test_db', 'test_table');

        $results = $store->query($queryVector);

        $this->assertCount(1, $results);
        $this->assertInstanceOf(VectorDocument::class, $results[0]);
        $this->assertSame(0.95, $results[0]->score);
        $this->assertSame(['title' => 'Test Document'], $results[0]->metadata->getArrayCopy());
    }

    public function testQueryWithOptions()
    {
        $queryVector = new Vector([0.1, 0.2, 0.3]);

        $httpClient = new MockHttpClient(function (string $method, string $url, array $options) {
            $this->assertSame('GET', $method);
            $this->assertStringContainsString('?', $url); // Check that URL has query parameters
            $this->assertSame(10, $options['query']['param_limit']);
            // Entries of the "params" option are expected to be forwarded with a "param_" prefix.
            $this->assertSame('test_value', $options['query']['param_custom_param']);

            return new MockResponse(json_encode(['data' => []]));
        });

        $store = new Store($httpClient, 'test_db', 'test_table');

        $results = $store->query($queryVector, [
            'limit' => 10,
            'params' => ['custom_param' => 'test_value'],
        ]);

        $this->assertCount(0, $results);
    }

    public function testQueryWithWhereClause()
    {
        $queryVector = new Vector([0.1, 0.2, 0.3]);

        $httpClient = new MockHttpClient(function (string $method, string $url, array $options) {
            $this->assertSame('GET', $method);
            $this->assertStringContainsString('?', $url); // Check that URL has query parameters
            $this->assertStringContainsString("AND JSONExtractString(metadata, 'type') = 'document'", $options['query']['query']);

            return new MockResponse(json_encode(['data' => []]));
        });

        $store = new Store($httpClient, 'test_db', 'test_table');

        $results = $store->query($queryVector, [
            'where' => "JSONExtractString(metadata, 'type') = 'document'",
        ]);

        $this->assertCount(0, $results);
    }

    public function testQueryWithNullMetadata()
    {
        $queryVector = new Vector([0.1, 0.2, 0.3]);
        $uuid = Uuid::v4();

        $responseData = [
            'data' => [
                [
                    'id' => $uuid->toRfc4122(),
                    'embedding' => [0.1, 0.2, 0.3],
                    'metadata' => null,
                    'score' => 0.95,
                ],
            ],
        ];

        $httpClient = new MockHttpClient(
            static fn (): MockResponse => new MockResponse(json_encode($responseData)),
        );

        $store = new Store($httpClient, 'test_db', 'test_table');

        $results = $store->query($queryVector);

        $this->assertCount(1, $results);
        // A null metadata column must be mapped to an empty Metadata object.
        $this->assertSame([], $results[0]->metadata->getArrayCopy());
    }
}
testStoreCannotAddOnInvalidResponse() 'test', ); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:7700/indexes/test/documents".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:7700/indexes/test/documents".'); + $this->expectExceptionCode(400); $store->add(new VectorDocument(Uuid::v4(), new Vector([0.1, 0.2, 0.3]))); } @@ -157,9 +157,9 @@ public function testStoreCannotQueryOnInvalidResponse() 'test', ); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:7700/indexes/test/search".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:7700/indexes/test/search".'); + $this->expectExceptionCode(400); $store->query(new Vector([0.1, 0.2, 0.3])); } diff --git a/src/store/tests/Bridge/Neo4j/StoreTest.php b/src/store/tests/Bridge/Neo4j/StoreTest.php index 698c71379..946b09231 100644 --- a/src/store/tests/Bridge/Neo4j/StoreTest.php +++ b/src/store/tests/Bridge/Neo4j/StoreTest.php @@ -24,7 +24,7 @@ #[CoversClass(Store::class)] final class StoreTest extends TestCase { - public function testStoreCannotInitializeOnInvalidResponse(): void + public function testStoreCannotInitializeOnInvalidResponse() { $httpClient = new MockHttpClient([ new JsonMockResponse([], [ @@ -34,13 +34,13 @@ public function testStoreCannotInitializeOnInvalidResponse(): void $store = new Store($httpClient, 'http://localhost:7474', 'symfony', 'symfony', 'symfony', 'symfony', 'symfony'); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:7474/db/symfony/query/v2".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 
400 returned for "http://localhost:7474/db/symfony/query/v2".'); + $this->expectExceptionCode(400); $store->initialize(); } - public function testStoreCanInitialize(): void + public function testStoreCanInitialize() { $httpClient = new MockHttpClient([ new JsonMockResponse([ @@ -85,7 +85,7 @@ public function testStoreCanInitialize(): void $this->assertSame(2, $httpClient->getRequestsCount()); } - public function testStoreCanAdd(): void + public function testStoreCanAdd() { $httpClient = new MockHttpClient([ new JsonMockResponse([ @@ -164,7 +164,7 @@ public function testStoreCanAdd(): void $this->assertSame(3, $httpClient->getRequestsCount()); } - public function testStoreCanQuery(): void + public function testStoreCanQuery() { $httpClient = new MockHttpClient([ new JsonMockResponse([ diff --git a/src/store/tests/Bridge/Qdrant/StoreTest.php b/src/store/tests/Bridge/Qdrant/StoreTest.php index 2e317adfe..e685a0c00 100644 --- a/src/store/tests/Bridge/Qdrant/StoreTest.php +++ b/src/store/tests/Bridge/Qdrant/StoreTest.php @@ -42,9 +42,9 @@ public function testStoreCannotInitializeOnInvalidResponse() $store = new Store($httpClient, 'http://localhost:6333', 'test', 'test'); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:6333/collections/test".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:6333/collections/test".'); + $this->expectExceptionCode(400); $store->initialize(); } @@ -110,9 +110,9 @@ public function testStoreCannotAddOnInvalidResponse() $store = new Store($httpClient, 'http://localhost:6333', 'test', 'test'); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:6333/collections/test/points".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + 
$this->expectExceptionMessage('HTTP 400 returned for "http://localhost:6333/collections/test/points".'); + $this->expectExceptionCode(400); $store->add(new VectorDocument(Uuid::v4(), new Vector([0.1, 0.2, 0.3]))); } @@ -153,9 +153,9 @@ public function testStoreCannotQueryOnInvalidResponse() 'test', ); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:6333/collections/test/points/query".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:6333/collections/test/points/query".'); + $this->expectExceptionCode(400); $store->query(new Vector([0.1, 0.2, 0.3])); } diff --git a/src/store/tests/Bridge/SurrealDb/StoreTest.php b/src/store/tests/Bridge/SurrealDb/StoreTest.php index 5a92ea0a3..e4f9ee10b 100644 --- a/src/store/tests/Bridge/SurrealDb/StoreTest.php +++ b/src/store/tests/Bridge/SurrealDb/StoreTest.php @@ -9,7 +9,7 @@ * file that was distributed with this source code. 
*/ -namespace Bridge\SurrealDb; +namespace Symfony\AI\Store\Tests\Bridge\SurrealDb; use PHPUnit\Framework\Attributes\CoversClass; use PHPUnit\Framework\TestCase; @@ -34,9 +34,9 @@ public function testStoreCannotInitializeOnInvalidResponse() $store = new Store($httpClient, 'http://localhost:8000', 'test', 'test', 'test', 'test'); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/signin".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/signin".'); + $this->expectExceptionCode(400); $store->initialize(); } @@ -57,9 +57,9 @@ public function testStoreCannotInitializeOnValidAuthenticationResponse() $store = new Store($httpClient, 'http://localhost:8000', 'test', 'test', 'test', 'test'); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/sql".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/sql".'); + $this->expectExceptionCode(400); $store->initialize(); } @@ -118,9 +118,9 @@ public function testStoreCannotAddOnInvalidResponse() $store = new Store($httpClient, 'http://localhost:8000', 'test', 'test', 'test', 'test', 'test'); $store->initialize(); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/key/test".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/key/test".'); + $this->expectExceptionCode(400); $store->add(new VectorDocument(Uuid::v4(), new Vector([0.1, 0.2, 0.3]))); } @@ -151,9 +151,9 @@ public function testStoreCannotAddOnInvalidAddResponse() $store = new Store($httpClient, 
'http://localhost:8000', 'test', 'test', 'test', 'test', 'test'); $store->initialize(); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/key/test".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/key/test".'); + $this->expectExceptionCode(400); $store->add(new VectorDocument(Uuid::v4(), new Vector(array_fill(0, 1275, 0.1)))); } @@ -263,9 +263,9 @@ public function testStoreCannotQueryOnInvalidResponse() $store->add(new VectorDocument(Uuid::v4(), new Vector(array_fill(0, 1275, 0.1)))); - self::expectException(ClientException::class); - self::expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/sql".'); - self::expectExceptionCode(400); + $this->expectException(ClientException::class); + $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:8000/sql".'); + $this->expectExceptionCode(400); $store->query(new Vector(array_fill(0, 1275, 0.1))); } diff --git a/src/store/tests/Bridge/Typesense/StoreTest.php b/src/store/tests/Bridge/Typesense/StoreTest.php new file mode 100644 index 000000000..65a09cc7d --- /dev/null +++ b/src/store/tests/Bridge/Typesense/StoreTest.php @@ -0,0 +1,181 @@ + + * + * For the full copyright and license information, please view the LICENSE + * file that was distributed with this source code. 
/**
 * Tests the Typesense store against canned HTTP responses.
 *
 * Every test queues a single JSON response on a MockHttpClient and checks either the
 * exception raised for a 400 response or the number of requests sent for a 200 response.
 */
#[CoversClass(Store::class)]
#[UsesClass(VectorDocument::class)]
#[UsesClass(Vector::class)]
final class StoreTest extends TestCase
{
    private const ENDPOINT = 'http://localhost:8108';

    public function testStoreCannotInitializeOnExistingCollection()
    {
        $client = $this->createHttpClient(
            ['message' => 'A collection with name "test" already exists.'],
            400,
        );
        $store = $this->createStore($client);

        $this->expectException(ClientException::class);
        $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:8108/collections".');
        $this->expectExceptionCode(400);
        $store->initialize();
    }

    public function testStoreCanInitialize()
    {
        $client = $this->createHttpClient(['name' => 'test', 'num_documents' => 0], 200);

        $this->createStore($client)->initialize();

        $this->assertSame(1, $client->getRequestsCount());
    }

    public function testStoreCannotAddOnInvalidResponse()
    {
        $store = $this->createStore($this->createHttpClient([], 400));

        $this->expectException(ClientException::class);
        $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:8108/collections/test/documents".');
        $this->expectExceptionCode(400);
        $store->add(new VectorDocument(Uuid::v4(), new Vector([0.1, 0.2, 0.3])));
    }

    public function testStoreCanAdd()
    {
        $client = $this->createHttpClient([], 200);

        $this->createStore($client)->add(new VectorDocument(Uuid::v4(), new Vector([0.1, 0.2, 0.3])));

        $this->assertSame(1, $client->getRequestsCount());
    }

    public function testStoreCannotQueryOnInvalidResponse()
    {
        $store = $this->createStore($this->createHttpClient([], 400));

        $this->expectException(ClientException::class);
        $this->expectExceptionMessage('HTTP 400 returned for "http://localhost:8108/multi_search".');
        $this->expectExceptionCode(400);
        $store->query(new Vector([0.1, 0.2, 0.3]));
    }

    public function testStoreCanQuery()
    {
        // Two hits sharing the same vector and metadata; only the ids differ.
        $firstHit = [
            'document' => [
                'id' => Uuid::v4()->toRfc4122(),
                'vector' => [0.1, 0.2, 0.3],
                'metadata' => '{"foo":"bar"}',
            ],
            'vector_distance' => 1.0,
        ];
        $secondHit = [
            'document' => [
                'id' => Uuid::v4()->toRfc4122(),
                'vector' => [0.1, 0.2, 0.3],
                'metadata' => '{"foo":"bar"}',
            ],
            'vector_distance' => 1.0,
        ];

        $client = $this->createHttpClient([
            'results' => [
                ['hits' => [$firstHit, $secondHit]],
            ],
        ], 200);

        $results = $this->createStore($client)->query(new Vector([0.1, 0.2, 0.3]));

        $this->assertCount(2, $results);
        $this->assertSame(1, $client->getRequestsCount());
    }

    /**
     * Builds a mock client that answers its first request with the given JSON payload and status code.
     *
     * @param array<mixed> $payload
     */
    private function createHttpClient(array $payload, int $statusCode): MockHttpClient
    {
        return new MockHttpClient([
            new JsonMockResponse($payload, [
                'http_code' => $statusCode,
            ]),
        ], self::ENDPOINT);
    }

    /**
     * Builds the store under test with the same constructor arguments in every test.
     */
    private function createStore(MockHttpClient $httpClient): Store
    {
        return new Store(
            $httpClient,
            self::ENDPOINT,
            'test',
            'test',
        );
    }
}
/**
 * Tests CacheStore on top of an in-memory ArrayAdapter.
 *
 * Covers accumulation of documents across several add() calls, result ordering,
 * the "maxItems" option and each DistanceStrategy passed through a DistanceCalculator.
 */
#[CoversClass(CacheStore::class)]
#[UsesClass(VectorDocument::class)]
#[UsesClass(Vector::class)]
final class CacheStoreTest extends TestCase
{
    public function testStoreCanSearchUsingCosineDistance()
    {
        $coordinates = [
            [0.1, 0.1, 0.5],
            [0.7, -0.3, 0.0],
            [0.3, 0.7, 0.1],
        ];

        $store = new CacheStore(new ArrayAdapter());
        $store->add(...$this->vectorDocuments(...$coordinates));

        $firstResult = $store->query(new Vector([0.0, 0.1, 0.6]));
        $this->assertCount(3, $firstResult);
        $this->assertSame([0.1, 0.1, 0.5], $firstResult[0]->vector->getData());

        // A second add() must append to the documents already cached, not replace them.
        $store->add(...$this->vectorDocuments(...$coordinates));

        $secondResult = $store->query(new Vector([0.0, 0.1, 0.6]));
        $this->assertCount(6, $secondResult);
        $this->assertSame([0.1, 0.1, 0.5], $secondResult[0]->vector->getData());
    }

    public function testStoreCanSearchUsingCosineDistanceAndReturnCorrectOrder()
    {
        $store = new CacheStore(new ArrayAdapter());
        $store->add(...$this->vectorDocuments(
            [0.1, 0.1, 0.5],
            [0.7, -0.3, 0.0],
            [0.3, 0.7, 0.1],
            [0.3, 0.1, 0.6],
            [0.0, 0.1, 0.6],
        ));

        $result = $store->query(new Vector([0.0, 0.1, 0.6]));

        $this->assertCount(5, $result);

        // Expected ranking, closest first.
        $expectedOrder = [
            [0.0, 0.1, 0.6],
            [0.1, 0.1, 0.5],
            [0.3, 0.1, 0.6],
            [0.3, 0.7, 0.1],
            [0.7, -0.3, 0.0],
        ];

        foreach ($expectedOrder as $position => $expectedVector) {
            $this->assertSame($expectedVector, $result[$position]->vector->getData());
        }
    }

    public function testStoreCanSearchUsingCosineDistanceWithMaxItems()
    {
        $store = new CacheStore(new ArrayAdapter());
        $store->add(...$this->vectorDocuments(
            [0.1, 0.1, 0.5],
            [0.7, -0.3, 0.0],
            [0.3, 0.7, 0.1],
        ));

        $result = $store->query(new Vector([0.0, 0.1, 0.6]), [
            'maxItems' => 1,
        ]);

        $this->assertCount(1, $result);
    }

    public function testStoreCanSearchUsingAngularDistance()
    {
        $store = new CacheStore(new ArrayAdapter(), new DistanceCalculator(DistanceStrategy::ANGULAR_DISTANCE));
        $store->add(...$this->vectorDocuments(
            [1.0, 2.0, 3.0],
            [1.0, 5.0, 7.0],
        ));

        $result = $store->query(new Vector([1.2, 2.3, 3.4]));

        $this->assertCount(2, $result);
        $this->assertSame([1.0, 2.0, 3.0], $result[0]->vector->getData());
    }

    public function testStoreCanSearchUsingEuclideanDistance()
    {
        $store = new CacheStore(new ArrayAdapter(), new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE));
        // The closest document is inserted last so the test cannot pass on insertion order alone.
        $store->add(...$this->vectorDocuments(
            [1.0, 5.0, 7.0],
            [1.0, 2.0, 3.0],
        ));

        $result = $store->query(new Vector([1.2, 2.3, 3.4]));

        $this->assertCount(2, $result);
        $this->assertSame([1.0, 2.0, 3.0], $result[0]->vector->getData());
    }

    public function testStoreCanSearchUsingManhattanDistance()
    {
        $store = new CacheStore(new ArrayAdapter(), new DistanceCalculator(DistanceStrategy::MANHATTAN_DISTANCE));
        $store->add(...$this->vectorDocuments(
            [1.0, 2.0, 3.0],
            [1.0, 5.0, 7.0],
        ));

        $result = $store->query(new Vector([1.2, 2.3, 3.4]));

        $this->assertCount(2, $result);
        $this->assertSame([1.0, 2.0, 3.0], $result[0]->vector->getData());
    }

    public function testStoreCanSearchUsingChebyshevDistance()
    {
        $store = new CacheStore(new ArrayAdapter(), new DistanceCalculator(DistanceStrategy::CHEBYSHEV_DISTANCE));
        $store->add(...$this->vectorDocuments(
            [1.0, 2.0, 3.0],
            [1.0, 5.0, 7.0],
        ));

        $result = $store->query(new Vector([1.2, 2.3, 3.4]));

        $this->assertCount(2, $result);
        $this->assertSame([1.0, 2.0, 3.0], $result[0]->vector->getData());
    }

    /**
     * Wraps each list of coordinates in a VectorDocument with a fresh random UUID, preserving order.
     *
     * @param list<float> ...$coordinates
     *
     * @return list<VectorDocument>
     */
    private function vectorDocuments(array ...$coordinates): array
    {
        return array_map(
            static fn (array $values): VectorDocument => new VectorDocument(Uuid::v4(), new Vector($values)),
            array_values($coordinates),
        );
    }
}
b/src/store/tests/Document/Transformer/TextSplitTransformerTest.php @@ -163,8 +163,8 @@ public function testSplitWithChunkSizeLargerThanText() public function testSplitWithOverlapGreaterThanChunkSize() { $document = new TextDocument(Uuid::v4(), 'Abcdefg', new Metadata([])); - self::expectException(InvalidArgumentException::class); - self::expectExceptionMessage('Overlap must be non-negative and less than chunk size.'); + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Overlap must be non-negative and less than chunk size.'); iterator_to_array(($this->transformer)([$document], [ TextSplitTransformer::OPTION_CHUNK_SIZE => 10, @@ -175,8 +175,8 @@ public function testSplitWithOverlapGreaterThanChunkSize() public function testSplitWithNegativeOverlap() { $document = new TextDocument(Uuid::v4(), 'Abcdefg', new Metadata([])); - self::expectException(InvalidArgumentException::class); - self::expectExceptionMessage('Overlap must be non-negative and less than chunk size.'); + $this->expectException(InvalidArgumentException::class); + $this->expectExceptionMessage('Overlap must be non-negative and less than chunk size.'); iterator_to_array(($this->transformer)([$document], [ TextSplitTransformer::OPTION_CHUNK_SIZE => 10, diff --git a/src/store/tests/InMemoryStoreTest.php b/src/store/tests/InMemoryStoreTest.php index 116bab423..6c155539f 100644 --- a/src/store/tests/InMemoryStoreTest.php +++ b/src/store/tests/InMemoryStoreTest.php @@ -12,13 +12,18 @@ namespace Symfony\AI\Store\Tests; use PHPUnit\Framework\Attributes\CoversClass; +use PHPUnit\Framework\Attributes\UsesClass; use PHPUnit\Framework\TestCase; use Symfony\AI\Platform\Vector\Vector; +use Symfony\AI\Store\DistanceCalculator; +use Symfony\AI\Store\DistanceStrategy; use Symfony\AI\Store\Document\VectorDocument; use Symfony\AI\Store\InMemoryStore; use Symfony\Component\Uid\Uuid; #[CoversClass(InMemoryStore::class)] +#[UsesClass(VectorDocument::class)] +#[UsesClass(Vector::class)] 
final class InMemoryStoreTest extends TestCase { public function testStoreCanSearchUsingCosineDistance() @@ -81,7 +86,7 @@ public function testStoreCanSearchUsingCosineDistanceWithMaxItems() public function testStoreCanSearchUsingAngularDistance() { - $store = new InMemoryStore(InMemoryStore::ANGULAR_DISTANCE); + $store = new InMemoryStore(new DistanceCalculator(DistanceStrategy::ANGULAR_DISTANCE)); $store->add( new VectorDocument(Uuid::v4(), new Vector([1.0, 2.0, 3.0])), new VectorDocument(Uuid::v4(), new Vector([1.0, 5.0, 7.0])), @@ -95,7 +100,7 @@ public function testStoreCanSearchUsingAngularDistance() public function testStoreCanSearchUsingEuclideanDistance() { - $store = new InMemoryStore(InMemoryStore::EUCLIDEAN_DISTANCE); + $store = new InMemoryStore(new DistanceCalculator(DistanceStrategy::EUCLIDEAN_DISTANCE)); $store->add( new VectorDocument(Uuid::v4(), new Vector([1.0, 5.0, 7.0])), new VectorDocument(Uuid::v4(), new Vector([1.0, 2.0, 3.0])), @@ -109,7 +114,7 @@ public function testStoreCanSearchUsingEuclideanDistance() public function testStoreCanSearchUsingManhattanDistance() { - $store = new InMemoryStore(InMemoryStore::MANHATTAN_DISTANCE); + $store = new InMemoryStore(new DistanceCalculator(DistanceStrategy::MANHATTAN_DISTANCE)); $store->add( new VectorDocument(Uuid::v4(), new Vector([1.0, 2.0, 3.0])), new VectorDocument(Uuid::v4(), new Vector([1.0, 5.0, 7.0])), @@ -123,7 +128,7 @@ public function testStoreCanSearchUsingManhattanDistance() public function testStoreCanSearchUsingChebyshevDistance() { - $store = new InMemoryStore(InMemoryStore::CHEBYSHEV_DISTANCE); + $store = new InMemoryStore(new DistanceCalculator(DistanceStrategy::CHEBYSHEV_DISTANCE)); $store->add( new VectorDocument(Uuid::v4(), new Vector([1.0, 2.0, 3.0])), new VectorDocument(Uuid::v4(), new Vector([1.0, 5.0, 7.0])),