Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/misc/failover-transport.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <[email protected]>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

use Symfony\AI\Platform\Bridge\Ollama\PlatformFactory as OllamaPlatformFactory;
use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory as OpenAiPlatformFactory;
use Symfony\AI\Platform\FailoverPlatform;
use Symfony\AI\Platform\Message\Message;
use Symfony\AI\Platform\Message\MessageBag;

require_once dirname(__DIR__).'/bootstrap.php';

$ollamaPlatform = OllamaPlatformFactory::create(env('OLLAMA_HOST_URL'), http_client());
$openAiPlatform = OpenAiPlatformFactory::create(env('OPENAI_API_KEY'), http_client());

$platform = new FailoverPlatform([
$ollamaPlatform, // # Ollama will fail as 'gpt-4o' is not available in the catalog
$openAiPlatform,
]);

$result = $platform->invoke('gpt-4o', new MessageBag(
Message::forSystem('You are a helpful assistant.'),
Message::ofUser('Tina has one brother and one sister. How many sisters do Tina\'s siblings have?'),
));

echo $result->asText().\PHP_EOL;
10 changes: 10 additions & 0 deletions src/ai-bundle/config/options.php
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,16 @@
->end()
->end()
->end()
->arrayNode('failover')
->children()
->integerNode('retry_period')
->defaultValue(60)
->end()
->arrayNode('platforms')
->scalarPrototype()->end()
->end()
->end()
->end()
->arrayNode('gemini')
->children()
->stringNode('api_key')->isRequired()->end()
Expand Down
23 changes: 23 additions & 0 deletions src/ai-bundle/src/AiBundle.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

use Google\Auth\ApplicationDefaultCredentials;
use Google\Auth\FetchAuthTokenInterface;
use Psr\Log\LoggerInterface;
use Symfony\AI\Agent\Agent;
use Symfony\AI\Agent\AgentInterface;
use Symfony\AI\Agent\Attribute\AsInputProcessor;
Expand Down Expand Up @@ -76,6 +77,7 @@
use Symfony\AI\Platform\CachedPlatform;
use Symfony\AI\Platform\Capability;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\FailoverPlatform;
use Symfony\AI\Platform\Message\Content\File;
use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface;
use Symfony\AI\Platform\ModelClientInterface;
Expand Down Expand Up @@ -488,6 +490,27 @@ private function processPlatformConfig(string $type, array $platform, ContainerB
return;
}

if ('failover' === $type) {
$definition = (new Definition(FailoverPlatform::class))
->setLazy(true)
->setArguments([
array_map(
static fn (string $platform): Reference => new Reference($platform),
$platform['platforms'],
),
new Reference(ClockInterface::class),
$platform['retry_period'],
new Reference(LoggerInterface::class),
])
->addTag('proxy', ['interface' => PlatformInterface::class])
->addTag('ai.platform', ['name' => $type]);

$container->setDefinition('ai.platform.'.$type, $definition);
$container->registerAliasForArgument('ai.platform.'.$type, PlatformInterface::class, $type);

return;
}

if ('gemini' === $type) {
$platformId = 'ai.platform.gemini';
$definition = (new Definition(Platform::class))
Expand Down
141 changes: 141 additions & 0 deletions src/ai-bundle/tests/DependencyInjection/AiBundleTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
use PHPUnit\Framework\Attributes\TestWith;
use PHPUnit\Framework\TestCase;
use Probots\Pinecone\Client as PineconeClient;
use Psr\Log\LoggerInterface;
use Psr\Log\NullLogger;
use Symfony\AI\Agent\AgentInterface;
use Symfony\AI\Agent\Memory\MemoryInputProcessor;
use Symfony\AI\Agent\Memory\StaticMemoryProvider;
Expand All @@ -34,6 +36,7 @@
use Symfony\AI\Platform\Bridge\Ollama\OllamaApiCatalog;
use Symfony\AI\Platform\Capability;
use Symfony\AI\Platform\EventListener\TemplateRendererListener;
use Symfony\AI\Platform\FailoverPlatform;
use Symfony\AI\Platform\Message\TemplateRenderer\ExpressionLanguageTemplateRenderer;
use Symfony\AI\Platform\Message\TemplateRenderer\StringTemplateRenderer;
use Symfony\AI\Platform\Message\TemplateRenderer\TemplateRendererRegistry;
Expand Down Expand Up @@ -74,6 +77,8 @@
use Symfony\AI\Store\ManagedStoreInterface;
use Symfony\AI\Store\RetrieverInterface;
use Symfony\AI\Store\StoreInterface;
use Symfony\Component\Clock\ClockInterface;
use Symfony\Component\Clock\MonotonicClock;
use Symfony\Component\Config\Definition\Exception\InvalidConfigurationException;
use Symfony\Component\DependencyInjection\ContainerBuilder;
use Symfony\Component\DependencyInjection\ContainerInterface;
Expand Down Expand Up @@ -4016,6 +4021,133 @@ public function testElevenLabsPlatformWithApiCatalogCanBeRegistered()
$this->assertSame([['interface' => ModelCatalogInterface::class]], $modelCatalogDefinition->getTag('proxy'));
}

public function testFailoverPlatformCanBeCreated()
{
$container = $this->buildContainer([
'ai' => [
'platform' => [
'ollama' => [
'host_url' => 'http://127.0.0.1:11434',
],
'openai' => [
'api_key' => 'sk-openai_key_full',
],
'failover' => [
'platforms' => [
'ai.platform.ollama',
'ai.platform.openai',
],
],
],
],
]);

$this->assertTrue($container->hasDefinition('ai.platform.failover'));

$definition = $container->getDefinition('ai.platform.failover');

$this->assertTrue($definition->isLazy());
$this->assertSame(FailoverPlatform::class, $definition->getClass());

$this->assertCount(4, $definition->getArguments());
$this->assertCount(2, $definition->getArgument(0));
$this->assertEquals([
new Reference('ai.platform.ollama'),
new Reference('ai.platform.openai'),
], $definition->getArgument(0));
$this->assertInstanceOf(Reference::class, $definition->getArgument(1));
$this->assertSame(ClockInterface::class, (string) $definition->getArgument(1));
$this->assertSame(60, $definition->getArgument(2));
$this->assertInstanceOf(Reference::class, $definition->getArgument(3));
$this->assertSame(LoggerInterface::class, (string) $definition->getArgument(3));

$this->assertTrue($definition->hasTag('proxy'));
$this->assertSame([['interface' => PlatformInterface::class]], $definition->getTag('proxy'));
$this->assertTrue($definition->hasTag('ai.platform'));
$this->assertSame([['name' => 'failover']], $definition->getTag('ai.platform'));

$this->assertTrue($container->hasAlias('Symfony\AI\Platform\PlatformInterface $failover'));

$container = $this->buildContainer([
'ai' => [
'platform' => [
'ollama' => [
'host_url' => 'http://127.0.0.1:11434',
],
'openai' => [
'api_key' => 'sk-openai_key_full',
],
'failover' => [
'platforms' => [
'ai.platform.ollama',
'ai.platform.openai',
],
'retry_period' => 120,
],
],
],
]);

$this->assertTrue($container->hasDefinition('ai.platform.failover'));

$definition = $container->getDefinition('ai.platform.failover');

$this->assertTrue($definition->isLazy());
$this->assertSame(FailoverPlatform::class, $definition->getClass());

$this->assertCount(4, $definition->getArguments());
$this->assertCount(2, $definition->getArgument(0));
$this->assertInstanceOf(Reference::class, $definition->getArgument(1));
$this->assertSame(ClockInterface::class, (string) $definition->getArgument(1));
$this->assertSame(120, $definition->getArgument(2));
$this->assertInstanceOf(Reference::class, $definition->getArgument(3));
$this->assertSame(LoggerInterface::class, (string) $definition->getArgument(3));

$this->assertTrue($definition->hasTag('proxy'));
$this->assertSame([['interface' => PlatformInterface::class]], $definition->getTag('proxy'));
$this->assertTrue($definition->hasTag('ai.platform'));
$this->assertSame([['name' => 'failover']], $definition->getTag('ai.platform'));

$this->assertTrue($container->hasAlias('Symfony\AI\Platform\PlatformInterface $failover'));
}

#[TestDox('Token usage processor tags use the correct agent ID')]
public function testTokenUsageProcessorTags()
{
$container = $this->buildContainer([
'ai' => [
'platform' => [
'openai' => [
'api_key' => 'sk-test_key',
],
],
'agent' => [
'tracked_agent' => [
'platform' => 'ai.platform.openai',
'model' => 'gpt-4',
'track_token_usage' => true,
],
],
],
]);

$agentId = 'ai.agent.tracked_agent';

// Token usage processor must exist for OpenAI platform
$tokenUsageProcessor = $container->getDefinition('ai.platform.token_usage_processor.openai');
$outputTags = $tokenUsageProcessor->getTag('ai.agent.output_processor');

$foundTag = false;
foreach ($outputTags as $tag) {
if (($tag['agent'] ?? '') === $agentId) {
$foundTag = true;
break;
}
}

$this->assertTrue($foundTag, 'Token usage processor should have output tag with full agent ID');
}

public function testOpenAiPlatformWithDefaultRegion()
{
$container = $this->buildContainer([
Expand Down Expand Up @@ -6987,6 +7119,8 @@ private function buildContainer(array $configuration): ContainerBuilder
$container->setParameter('kernel.debug', true);
$container->setParameter('kernel.environment', 'dev');
$container->setParameter('kernel.build_dir', 'public');
$container->setDefinition(ClockInterface::class, new Definition(MonotonicClock::class));
$container->setDefinition(LoggerInterface::class, new Definition(NullLogger::class));

$extension = (new AiBundle())->getContainerExtension();
$extension->load($configuration, $container);
Expand Down Expand Up @@ -7042,6 +7176,13 @@ private function getFullConfig(): array
'host' => 'https://api.elevenlabs.io/v1',
'api_key' => 'elevenlabs_key_full',
],
'failover' => [
'platforms' => [
'ai.platform.ollama',
'ai.platform.openai',
],
'retry_period' => 120,
],
'gemini' => [
'api_key' => 'gemini_key_full',
],
Expand Down
19 changes: 19 additions & 0 deletions src/platform/src/Exception/LogicException.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <[email protected]>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace Symfony\AI\Platform\Exception;

/**
* @author Guillaume Loulier <[email protected]>
*/
final class LogicException extends \LogicException implements ExceptionInterface
{
}
86 changes: 86 additions & 0 deletions src/platform/src/FailoverPlatform.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
<?php

/*
* This file is part of the Symfony package.
*
* (c) Fabien Potencier <[email protected]>
*
* For the full copyright and license information, please view the LICENSE
* file that was distributed with this source code.
*/

namespace Symfony\AI\Platform;

use Psr\Log\LoggerInterface;
use Psr\Log\NullLogger;
use Symfony\AI\Platform\Exception\LogicException;
use Symfony\AI\Platform\Exception\RuntimeException;
use Symfony\AI\Platform\ModelCatalog\ModelCatalogInterface;
use Symfony\AI\Platform\Result\DeferredResult;
use Symfony\Component\Clock\ClockInterface;
use Symfony\Component\Clock\MonotonicClock;

/**
* @author Guillaume Loulier <[email protected]>
*/
final class FailoverPlatform implements PlatformInterface
{
/**
* @var \SplObjectStorage<PlatformInterface, int>
*/
private \SplObjectStorage $deadPlatforms;

/**
* @param PlatformInterface[] $platforms
*/
public function __construct(
private readonly iterable $platforms,
private readonly ClockInterface $clock = new MonotonicClock(),
private readonly int $retryPeriod = 60,
private readonly LoggerInterface $logger = new NullLogger(),
) {
if ([] === $platforms) {
throw new LogicException(\sprintf('"%s" must have at least one platform configured.', self::class));
}

$this->deadPlatforms = new \SplObjectStorage();
}

public function invoke(string $model, object|array|string $input, array $options = []): DeferredResult
{
return $this->do(static fn (PlatformInterface $platform): DeferredResult => $platform->invoke($model, $input, $options));
}

public function getModelCatalog(): ModelCatalogInterface
{
return $this->do(static fn (PlatformInterface $platform): ModelCatalogInterface => $platform->getModelCatalog());
}

private function do(\Closure $func): DeferredResult|ModelCatalogInterface
{
foreach ($this->platforms as $platform) {
if ($this->deadPlatforms->offsetExists($platform) && ($this->clock->now()->getTimestamp() - $this->deadPlatforms[$platform]) > $this->retryPeriod) {
$this->deadPlatforms->offsetUnset($platform);
}

if ($this->deadPlatforms->offsetExists($platform)) {
continue;
}

try {
return $func($platform);
} catch (\Throwable $throwable) {
$this->deadPlatforms->offsetSet($platform, $this->clock->now()->getTimestamp());

$this->logger->warning('The {platform} platform has encountered an exception: {exception}', [
'platform' => $platform::class,
'exception' => $throwable->getMessage(),
]);

continue;
}
}

throw new RuntimeException('All platforms failed.');
}
}
Loading
Loading