diff --git a/docs/.vitepress/config.mts b/docs/.vitepress/config.mts index 570b169..42ba3cf 100644 --- a/docs/.vitepress/config.mts +++ b/docs/.vitepress/config.mts @@ -102,6 +102,10 @@ export default defineConfig({ text: "Tool & Function Calling", link: "/core-concepts/tools-function-calling", }, + { + text: "Embeddings", + link: "/core-concepts/embeddings", + }, { text: "Prism Server", link: "/core-concepts/prism-server", diff --git a/docs/core-concepts/embeddings.md b/docs/core-concepts/embeddings.md new file mode 100644 index 0000000..2fcbc95 --- /dev/null +++ b/docs/core-concepts/embeddings.md @@ -0,0 +1,107 @@ +# Embeddings + +Transform your text into powerful vector representations! Embeddings let you add semantic search, recommendation systems, and other advanced natural language features to your applications. + +## Quick Start + +Here's how to generate embeddings with just a few lines of code: + +```php +use EchoLabs\Prism\Facades\Prism; +use EchoLabs\Prism\Enums\Provider; + +$response = Prism::embeddings() + ->using(Provider::OpenAI, 'text-embedding-3-large') + ->fromInput('Your text goes here') + ->generate(); + +// Get your embeddings vector +$embeddings = $response->embeddings; + +// Check token usage +echo $response->usage->tokens; +``` + +## Input Methods + +You've got two convenient ways to feed text into the embeddings generator: + +### Direct Text Input + +```php +$response = Prism::embeddings() + ->using(Provider::OpenAI, 'text-embedding-3-large') + ->fromInput('Analyze this text') + ->generate(); +``` + +### From File + +Need to analyze a larger document? No problem: + +```php +$response = Prism::embeddings() + ->using(Provider::OpenAI, 'text-embedding-3-large') + ->fromFile('/path/to/your/document.txt') + ->generate(); +``` + +> [!NOTE] +> Make sure your file exists and is readable. The generator will throw a helpful `PrismException` if there's any issue accessing the file. + +## Common Settings + +Just like with text generation, you can fine-tune your embeddings requests: + +```php +$response = Prism::embeddings() + ->using(Provider::OpenAI, 'text-embedding-3-large') + ->fromInput('Your text here') + ->withClientOptions(['timeout' => 30]) // Adjust request timeout + ->withClientRetry(3, 100) // Add automatic retries + ->generate(); +``` + +## Response Handling + +The embeddings response gives you everything you need: + +```php +// Get the embeddings vector +$vector = $response->embeddings; + +// Check token usage +$tokenCount = $response->usage->tokens; +``` + +## Error Handling + +Always handle potential errors gracefully: + +```php +use EchoLabs\Prism\Exceptions\PrismException; + +try { + $response = Prism::embeddings() + ->using(Provider::OpenAI, 'text-embedding-3-large') + ->fromInput('Your text here') + ->generate(); +} catch (PrismException $e) { + Log::error('Embeddings generation failed:', [ + 'error' => $e->getMessage() + ]); +} +``` + +## Pro Tips 🌟 + +**Vector Storage**: Consider using a vector database like Milvus, Qdrant, or pgvector to store and query your embeddings efficiently. + +**Text Preprocessing**: For best results, clean and normalize your text before generating embeddings. This might include: + - Removing unnecessary whitespace + - Converting to lowercase + - Removing special characters + - Handling Unicode normalization + +> [!IMPORTANT] +> Different providers and models produce vectors of different dimensions. Always check your provider's documentation for specific details about the embedding model you're using. diff --git a/src/Contracts/Provider.php b/src/Contracts/Provider.php index 4063fc4..8e9c2a4 100644 --- a/src/Contracts/Provider.php +++ b/src/Contracts/Provider.php @@ -4,6 +4,8 @@ namespace EchoLabs\Prism\Contracts; +use EchoLabs\Prism\Embeddings\Request as EmbeddingsRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingsResponse; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Structured\Request as StructuredRequest; use EchoLabs\Prism\Text\Request as TextRequest; @@ -13,4 +15,6 @@ interface Provider public function text(TextRequest $request): ProviderResponse; public function structured(StructuredRequest $request): ProviderResponse; + + public function embeddings(EmbeddingsRequest $request): EmbeddingsResponse; } diff --git a/src/Embeddings/Generator.php b/src/Embeddings/Generator.php new file mode 100644 index 0000000..32afb31 --- /dev/null +++ b/src/Embeddings/Generator.php @@ -0,0 +1,74 @@ + */ + protected array $clientOptions = []; + + /** @var array{0: array|int, 1?: Closure|int, 2?: ?callable, 3?: bool} */ + protected array $clientRetry = [0]; + + protected Provider $provider; + + protected string $model; + + public function using(string|ProviderEnum $provider, string $model): self + { + $this->provider = resolve(PrismManager::class) + ->resolve($provider); + + $this->model = $model; + + return $this; + } + + public function fromInput(string $input): self + { + $this->input = $input; + + return $this; + } + + public function fromFile(string $path): self + { + if (! is_file($path)) { + throw new PrismException(sprintf('%s is not a valid file', $path)); + } + + $contents = file_get_contents($path); + + if ($contents === false) { + throw new PrismException(sprintf('%s contents could not be read', $path)); + } + + $this->input = $contents; + + return $this; + } + + public function generate(): Response + { + if ($this->input === '' || $this->input === '0') { + throw new PrismException('Embeddings input is required'); + } + + return $this->provider->embeddings(new Request( + model: $this->model, + input: $this->input, + clientOptions: $this->clientOptions, + clientRetry: $this->clientRetry, + )); + } +} diff --git a/src/Embeddings/Request.php b/src/Embeddings/Request.php new file mode 100644 index 0000000..184544a --- /dev/null +++ b/src/Embeddings/Request.php @@ -0,0 +1,21 @@ + $clientOptions + * @param array{0: array|int, 1?: Closure|int, 2?: ?callable, 3?: bool} $clientRetry + */ + public function __construct( + public readonly string $model, + public readonly string $input, + public readonly array $clientOptions, + public readonly array $clientRetry, + ) {} +} diff --git a/src/Embeddings/Response.php b/src/Embeddings/Response.php new file mode 100644 index 0000000..23335ff --- /dev/null +++ b/src/Embeddings/Response.php @@ -0,0 +1,18 @@ + $embeddings + */ + public function __construct( + public readonly array $embeddings, + public readonly EmbeddingsUsage $usage, + ) {} +} diff --git a/src/Prism.php b/src/Prism.php index 2a5bc0b..57f2ea3 100644 --- a/src/Prism.php +++ b/src/Prism.php @@ -5,6 +5,7 @@ namespace EchoLabs\Prism; use EchoLabs\Prism\Contracts\Provider; +use EchoLabs\Prism\Embeddings\Generator as EmbeddingsGenerator; use EchoLabs\Prism\Enums\Provider as ProviderEnum; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Structured\Generator as StructuredGenerator; @@ -44,4 +45,9 @@ public static function structured(): StructuredGenerator { return new StructuredGenerator; } + + public static function embeddings(): \EchoLabs\Prism\Embeddings\Generator + { + return new EmbeddingsGenerator; + } } diff --git a/src/Providers/Anthropic/Anthropic.php b/src/Providers/Anthropic/Anthropic.php index 09215ab..5a34015 100644 --- a/src/Providers/Anthropic/Anthropic.php +++ b/src/Providers/Anthropic/Anthropic.php @@ -5,6 +5,8 @@ namespace EchoLabs\Prism\Providers\Anthropic; use EchoLabs\Prism\Contracts\Provider; +use EchoLabs\Prism\Embeddings\Request as EmbeddingRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse; use EchoLabs\Prism\Providers\Anthropic\Handlers\Text; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Structured\Request as StructuredRequest; @@ -33,6 +35,12 @@ public function structured(StructuredRequest $request): ProviderResponse throw new \Exception(sprintf('%s does not support structured mode', class_basename($this))); } + #[\Override] + public function embeddings(EmbeddingRequest $request): EmbeddingResponse + { + throw new \Exception(sprintf('%s does not support embeddings', class_basename($this))); + } + /** * @param array $options * @param array $retry diff --git a/src/Providers/Groq/Groq.php b/src/Providers/Groq/Groq.php index 7e77da0..67dd362 100644 --- a/src/Providers/Groq/Groq.php +++ b/src/Providers/Groq/Groq.php @@ -5,6 +5,8 @@ namespace EchoLabs\Prism\Providers\Groq; use EchoLabs\Prism\Contracts\Provider; +use EchoLabs\Prism\Embeddings\Request as EmbeddingRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse; use EchoLabs\Prism\Providers\Groq\Handlers\Text; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Structured\Request as StructuredRequest; @@ -33,6 +35,12 @@ public function structured(StructuredRequest $request): ProviderResponse throw new \Exception(sprintf('%s does not support structured mode', class_basename($this))); } + #[\Override] + public function embeddings(EmbeddingRequest $request): EmbeddingResponse + { + throw new \Exception(sprintf('%s does not support embeddings', class_basename($this))); + } + /** * @param array $options * @param array $retry diff --git a/src/Providers/Mistral/Mistral.php b/src/Providers/Mistral/Mistral.php index ff5a456..92bde64 100644 --- a/src/Providers/Mistral/Mistral.php +++ b/src/Providers/Mistral/Mistral.php @@ -5,6 +5,8 @@ namespace EchoLabs\Prism\Providers\Mistral; use EchoLabs\Prism\Contracts\Provider; +use EchoLabs\Prism\Embeddings\Request as EmbeddingRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse; use EchoLabs\Prism\Providers\Mistral\Handlers\Text; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Structured\Request as StructuredRequest; @@ -33,6 +35,12 @@ public function structured(StructuredRequest $request): ProviderResponse throw new \Exception(sprintf('%s does not support structured mode', class_basename($this))); } + #[\Override] + public function embeddings(EmbeddingRequest $request): EmbeddingResponse + { + throw new \Exception(sprintf('%s does not support embeddings', class_basename($this))); + } + /** * @param array $options * @param array $retry diff --git a/src/Providers/Ollama/Ollama.php b/src/Providers/Ollama/Ollama.php index cb0ff44..c5ef0eb 100644 --- a/src/Providers/Ollama/Ollama.php +++ b/src/Providers/Ollama/Ollama.php @@ -5,6 +5,8 @@ namespace EchoLabs\Prism\Providers\Ollama; use EchoLabs\Prism\Contracts\Provider; +use EchoLabs\Prism\Embeddings\Request as EmbeddingRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse; use EchoLabs\Prism\Providers\Ollama\Handlers\Text; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Structured\Request as StructuredRequest; @@ -33,6 +35,12 @@ public function structured(StructuredRequest $request): ProviderResponse throw new \Exception(sprintf('%s does not support structured mode', class_basename($this))); } + #[\Override] + public function embeddings(EmbeddingRequest $request): EmbeddingResponse + { + throw new \Exception(sprintf('%s does not support embeddings', class_basename($this))); + } + /** * @param array $options * @param array $retry diff --git a/src/Providers/OpenAI/Handlers/Embeddings.php b/src/Providers/OpenAI/Handlers/Embeddings.php new file mode 100644 index 0000000..d4713bc --- /dev/null +++ b/src/Providers/OpenAI/Handlers/Embeddings.php @@ -0,0 +1,55 @@ +sendRequest($request); + } catch (Throwable $e) { + throw PrismException::providerRequestError($request->model, $e); + } + + $data = $response->json(); + + if (data_get($data, 'error') || ! $data) { + throw PrismException::providerResponseError(vsprintf( + 'OpenAI Error: [%s] %s', + [ + data_get($data, 'error.type', 'unknown'), + data_get($data, 'error.message', 'unknown'), + ] + )); + } + + return new EmbeddingsResponse( + embeddings: data_get($data, 'data.0.embedding', []), + usage: new EmbeddingsUsage(data_get($data, 'usage.total_tokens', null)), + ); + } + + protected function sendRequest(Request $request): Response + { + return $this->client->post( + 'embeddings', + [ + 'model' => $request->model, + 'input' => $request->input, + ] + ); + } +} diff --git a/src/Providers/OpenAI/OpenAI.php b/src/Providers/OpenAI/OpenAI.php index 5e0c67a..e0c71bc 100644 --- a/src/Providers/OpenAI/OpenAI.php +++ b/src/Providers/OpenAI/OpenAI.php @@ -6,6 +6,9 @@ use Closure; use EchoLabs\Prism\Contracts\Provider; +use EchoLabs\Prism\Embeddings\Request as EmbeddingsRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingsResponse; +use EchoLabs\Prism\Providers\OpenAI\Handlers\Embeddings; use EchoLabs\Prism\Providers\OpenAI\Handlers\Structured; use EchoLabs\Prism\Providers\OpenAI\Handlers\Text; use EchoLabs\Prism\Providers\ProviderResponse; @@ -44,6 +47,17 @@ public function structured(StructuredRequest $request): ProviderResponse return $handler->handle($request); } + #[\Override] + public function embeddings(EmbeddingsRequest $request): EmbeddingsResponse + { + $handler = new Embeddings($this->client( + $request->clientOptions, + $request->clientRetry + )); + + return $handler->handle($request); + } + /** * @param array $options * @param array{0: array|int, 1?: Closure|int, 2?: ?callable, 3?: bool} $retry diff --git a/src/Providers/XAI/XAI.php b/src/Providers/XAI/XAI.php index 2aae530..8f9c4f4 100644 --- a/src/Providers/XAI/XAI.php +++ b/src/Providers/XAI/XAI.php @@ -5,6 +5,8 @@ namespace EchoLabs\Prism\Providers\XAI; use EchoLabs\Prism\Contracts\Provider; +use EchoLabs\Prism\Embeddings\Request as EmbeddingRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Providers\XAI\Handlers\Text; use EchoLabs\Prism\Structured\Request as StructuredRequest; @@ -33,6 +35,12 @@ public function structured(StructuredRequest $request): ProviderResponse throw new \Exception(sprintf('%s does not support structured mode', class_basename($this))); } + #[\Override] + public function embeddings(EmbeddingRequest $request): EmbeddingResponse + { + throw new \Exception(sprintf('%s does not support embeddings', class_basename($this))); + } + /** * @param array $options * @param array $retry diff --git a/src/Testing/PrismFake.php b/src/Testing/PrismFake.php index 723219d..a349da2 100644 --- a/src/Testing/PrismFake.php +++ b/src/Testing/PrismFake.php @@ -6,10 +6,13 @@ use Closure; use EchoLabs\Prism\Contracts\Provider; +use EchoLabs\Prism\Embeddings\Request as EmbeddingRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse; use EchoLabs\Prism\Enums\FinishReason; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Structured\Request as StructuredRequest; use EchoLabs\Prism\Text\Request as TextRequest; +use EchoLabs\Prism\ValueObjects\EmbeddingsUsage; use EchoLabs\Prism\ValueObjects\Usage; use Exception; use PHPUnit\Framework\Assert as PHPUnit; @@ -18,11 +21,11 @@ class PrismFake implements Provider { protected int $responseSequence = 0; - /** @var array */ + /** @var array */ protected array $recorded = []; /** - * @param array $responses + * @param array $responses */ public function __construct(protected array $responses = []) {} @@ -31,7 +34,7 @@ public function text(TextRequest $request): ProviderResponse { $this->recorded[] = $request; - return $this->nextResponse() ?? new ProviderResponse( + return $this->nextProviderResponse() ?? new ProviderResponse( text: '', toolCalls: [], usage: new Usage(0, 0), @@ -40,12 +43,23 @@ public function text(TextRequest $request): ProviderResponse ); } + #[\Override] + public function embeddings(EmbeddingRequest $request): EmbeddingResponse + { + $this->recorded[] = $request; + + return $this->nextEmbeddingResponse() ?? new EmbeddingResponse( + embeddings: [], + usage: new EmbeddingsUsage(10), + ); + } + #[\Override] public function structured(StructuredRequest $request): ProviderResponse { $this->recorded[] = $request; - return $this->nextResponse() ?? new ProviderResponse( + return $this->nextProviderResponse() ?? new ProviderResponse( text: '', toolCalls: [], usage: new Usage(0, 0), @@ -55,7 +69,7 @@ public function structured(StructuredRequest $request): ProviderResponse } /** - * @param Closure(array):void $fn + * @param Closure(array):void $fn */ public function assertRequest(Closure $fn): void { @@ -85,12 +99,32 @@ public function assertCallCount(int $expectedCount): void PHPUnit::assertEquals($expectedCount, $actualCount, "Expected {$expectedCount} calls, got {$actualCount}"); } - protected function nextResponse(): ?ProviderResponse + protected function nextProviderResponse(): ?ProviderResponse + { + if (! isset($this->responses)) { + return null; + } + + /** @var ProviderResponse[] */ + $responses = $this->responses; + $sequence = $this->responseSequence; + + if (! isset($responses[$sequence])) { + throw new Exception('Could not find a response for the request'); + } + + $this->responseSequence++; + + return $responses[$sequence]; + } + + protected function nextEmbeddingResponse(): ?EmbeddingResponse { if (! isset($this->responses)) { return null; } + /** @var EmbeddingResponse[] */ $responses = $this->responses; $sequence = $this->responseSequence; diff --git a/src/ValueObjects/EmbeddingsUsage.php b/src/ValueObjects/EmbeddingsUsage.php new file mode 100644 index 0000000..f043054 --- /dev/null +++ b/src/ValueObjects/EmbeddingsUsage.php @@ -0,0 +1,12 @@ +using(Provider::Anthropic, 'claude-3-sonnet') + ->withPrompt('Tell me a short story about a brave knight.') + ->generate(); + +echo $response->text; +``` + +## System Prompts and Context + +System prompts help set the behavior and context for the AI. They're particularly useful for maintaining consistent responses or giving the LLM a persona: + +```php +$response = Prism::text() + ->using(Provider::Anthropic, 'claude-3-sonnet') + ->withSystemPrompt('You are an expert mathematician who explains concepts simply.') + ->withPrompt('Explain the Pythagorean theorem.') + ->generate(); +``` + +You can also use Laravel views for complex system prompts: + +```php +$response = Prism::text() + ->using(Provider::Anthropic, 'claude-3-sonnet') + ->withSystemPrompt(view('prompts.math-tutor')) + ->withPrompt('What is calculus?') + ->generate(); +``` + +You an also pass a View to the `withPrompt` method. + +## Message Chains and Conversations + +For interactive conversations, use message chains to maintain context: + +```php +use EchoLabs\Prism\ValueObjects\Messages\UserMessage; +use EchoLabs\Prism\ValueObjects\Messages\AssistantMessage; + +$response = Prism::text() + ->using(Provider::Anthropic, 'claude-3-sonnet') + ->withMessages([ + new UserMessage('What is JSON?'), + new AssistantMessage('JSON is a lightweight data format...'), + new UserMessage('Can you show me an example?') + ]) + ->generate(); +``` + +### Message Types + +- `SystemMessage` +- `UserMessage` +- `AssistantMessage` +- `ToolResultMessage` + +> [!NOTE] +> Some providers, like Anthropic, do not support the `SystemMessage` type. In those cases we convert `SystemMessage` to `UserMessage`. + +## Multi-modal Capabilities (Images) + +Prism supports including images in your messages for visual analysis: + +```php +use EchoLabs\Prism\ValueObjects\Messages\Support\Image; + +// From a local file +$message = new UserMessage( + "What's in this image?", + [Image::fromPath('/path/to/image.jpg')] +); + +// From a URL +$message = new UserMessage( + 'Analyze this diagram:', + [Image::fromUrl('https://example.com/diagram.png')] +); + +// From a Base64 +$image = base64_encode(file_get_contents('/path/to/image.jpg')); + +$message = new UserMessage( + 'Analyze this diagram:', + [Image::fromBase64($image)] +); + +$response = Prism::text() + ->using(Provider::Anthropic, 'claude-3-sonnet') + ->withMessages([$message]) + ->generate(); +``` + +## Generation Parameters + +Fine-tune your generations with various parameters: + +`withMaxTokens` + +Maximum number of tokens to generate. + +`usingTemperature` + +Temperature setting. + +The value is passed through to the provider. The range depends on the provider and model. For most providers, 0 means almost deterministic results, and higher values mean more randomness. + +> [!TIP] +> It is recommended to set either temperature or topP, but not both. + +`usingTopP` + +Nucleus sampling. + +The value is passed through to the provider. The range depends on the provider and model. For most providers, nucleus sampling is a number between 0 and 1. E.g. 0.1 would mean that only tokens with the top 10% probability mass are considered. + +> [!TIP] +> It is recommended to set either temperature or topP, but not both. + +`withClientOptions` + +Under the hood we use Laravel's [HTTP client](https://laravel.com/docs/11.x/http-client#main-content). You can use this method to pass any of Guzzles [request options](https://docs.guzzlephp.org/en/stable/request-options.html) e.g. `->withClientOptions(['timeout' => 30])`. + +`withClientRetry` + +Under the hood we use Laravel's [HTTP client](https://laravel.com/docs/11.x/http-client#main-content). You can use this method to set [retries](https://laravel.com/docs/11.x/http-client#retries) e.g. `->withClientRetry(3, 100)`. + +## Response Handling + +The response object provides rich access to the generation results: + +```php +$response = Prism::text() + ->using(Provider::Anthropic, 'claude-3-sonnet') + ->withPrompt('Explain quantum computing.') + ->generate(); + +// Access the generated text +echo $response->text; + +// Check why the generation stopped +echo $response->finishReason->name; + +// Get token usage statistics +echo "Prompt tokens: {$response->usage->promptTokens}"; +echo "Completion tokens: {$response->usage->completionTokens}"; + +// For multi-step generations, examine each step +foreach ($response->steps as $step) { + echo "Step text: {$step->text}"; + echo "Step tokens: {$step->usage->completionTokens}"; +} + +// Access message history +foreach ($response->responseMessages as $message) { + if ($message instanceof AssistantMessage) { + echo $message->content; + } +} +``` + +### Finish Reasons + +```php +case Stop; +case Length; +case ContentFilter; +case ToolCalls; +case Error; +case Other; +case Unknown; +``` + +## Error Handling + +Remember to handle potential errors in your generations: + +```php +use EchoLabs\Prism\Exceptions\PrismException; +use Throwable; + +try { + $response = Prism::text() + ->using(Provider::Anthropic, 'claude-3-sonnet') + ->withPrompt('Generate text...') + ->generate(); +} catch (PrismException $e) { + Log::error('Text generation failed:', ['error' => $e->getMessage()]); +} catch (Throwable $e) { + Log::error('Generic error:', ['error' => $e->getMessage]); +} +``` diff --git a/tests/Providers/OpenAI/EmbeddingsTest.php b/tests/Providers/OpenAI/EmbeddingsTest.php new file mode 100644 index 0000000..9fbf285 --- /dev/null +++ b/tests/Providers/OpenAI/EmbeddingsTest.php @@ -0,0 +1,39 @@ +set('prism.providers.openai.api_key', env('OPENAI_API_KEY')); +}); + +it('returns embeddings from input', function (): void { + FixtureResponse::fakeResponseSequence('v1/embeddings', 'openai/embeddings-input'); + + $response = Prism::embeddings() + ->using(Provider::OpenAI, 'text-embedding-ada-002') + ->fromInput('The food was delicious and the waiter...') + ->generate(); + + expect($response->embeddings)->toBeArray(); + expect($response->embeddings)->not->toBeEmpty(); + expect($response->usage->tokens)->toBe(8); +}); + +it('returns embeddings from file', function (): void { + FixtureResponse::fakeResponseSequence('v1/embeddings', 'openai/embeddings-file'); + + $response = Prism::embeddings() + ->using(Provider::OpenAI, 'text-embedding-ada-002') + ->fromFile('tests/Fixtures/test-embedding-file.md') + ->generate(); + + expect($response->embeddings)->toBeArray(); + expect($response->embeddings)->not->toBeEmpty(); + expect($response->usage->tokens)->toBe(1378); +}); diff --git a/tests/TestDoubles/TestProvider.php b/tests/TestDoubles/TestProvider.php index 881c3d5..cf71be4 100644 --- a/tests/TestDoubles/TestProvider.php +++ b/tests/TestDoubles/TestProvider.php @@ -5,15 +5,18 @@ namespace Tests\TestDoubles; use EchoLabs\Prism\Contracts\Provider; +use EchoLabs\Prism\Embeddings\Request as EmbeddingRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse; use EchoLabs\Prism\Enums\FinishReason; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Structured\Request as StructuredRequest; use EchoLabs\Prism\Text\Request as TextRequest; +use EchoLabs\Prism\ValueObjects\EmbeddingsUsage; use EchoLabs\Prism\ValueObjects\Usage; class TestProvider implements Provider { - public StructuredRequest|TextRequest $request; + public StructuredRequest|TextRequest|EmbeddingRequest $request; /** @var array */ public array $clientOptions; @@ -21,7 +24,7 @@ class TestProvider implements Provider /** @var array */ public array $clientRetry; - /** @var array */ + /** @var array */ public array $responses = []; public $callCount = 0; @@ -58,6 +61,19 @@ public function structured(StructuredRequest $request): ProviderResponse ); } + #[\Override] + public function embeddings(EmbeddingRequest $request): EmbeddingResponse + { + $this->callCount++; + + $this->request = $request; + + return $this->responses[$this->callCount - 1] ?? new EmbeddingResponse( + embeddings: [], + usage: new EmbeddingsUsage(10), + ); + } + public function withResponse(ProviderResponse $response): Provider { $this->responses[] = $response; diff --git a/tests/Testing/PrismFakeTest.php b/tests/Testing/PrismFakeTest.php index b1caa9f..88a2016 100644 --- a/tests/Testing/PrismFakeTest.php +++ b/tests/Testing/PrismFakeTest.php @@ -4,13 +4,17 @@ namespace Tests\Testing; +use EchoLabs\Prism\Embeddings\Request as EmbeddingRequest; +use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse; use EchoLabs\Prism\Enums\FinishReason; +use EchoLabs\Prism\Enums\Provider; use EchoLabs\Prism\Prism; use EchoLabs\Prism\Providers\ProviderResponse; use EchoLabs\Prism\Schema\ObjectSchema; use EchoLabs\Prism\Schema\StringSchema; use EchoLabs\Prism\Structured\Request as StructuredRequest; use EchoLabs\Prism\Text\Request as TextRequest; +use EchoLabs\Prism\ValueObjects\EmbeddingsUsage; use EchoLabs\Prism\ValueObjects\Usage; use Exception; @@ -69,6 +73,31 @@ }); }); +it('fake responses using the prism fake for emeddings', function (): void { + $fake = Prism::fake([ + new EmbeddingResponse( + embeddings: [ + -0.009639355, + -0.00047589254, + -0.022748338, + -0.005906468, + ], + usage: new EmbeddingsUsage(100) + ), + ]); + + Prism::embeddings() + ->using(Provider::OpenAI, 'text-embedding-ada-002') + ->fromInput('What is the meaning of life?') + ->generate(); + + $fake->assertCallCount(1); + $fake->assertRequest(function (array $requests): void { + expect($requests)->toHaveCount(1); + expect($requests[0])->toBeInstanceOf(EmbeddingRequest::class); + }); +}); + it("throws an exception when it can't runs out of responses", function (): void { $this->expectException(Exception::class); $this->expectExceptionMessage('Could not find a response for the request');