diff --git a/src/Providers/OpenRouter/Handlers/Structured.php b/src/Providers/OpenRouter/Handlers/Structured.php index d377df9c..1576268c 100644 --- a/src/Providers/OpenRouter/Handlers/Structured.php +++ b/src/Providers/OpenRouter/Handlers/Structured.php @@ -6,24 +6,29 @@ use Illuminate\Http\Client\PendingRequest; use Illuminate\Http\Client\Response; +use Prism\Prism\Concerns\CallsTools; +use Prism\Prism\Enums\FinishReason; use Prism\Prism\Exceptions\PrismException; use Prism\Prism\Exceptions\PrismStructuredDecodingException; +use Prism\Prism\Providers\DeepSeek\Maps\ToolCallMap; use Prism\Prism\Providers\OpenRouter\Concerns\BuildsRequestOptions; use Prism\Prism\Providers\OpenRouter\Concerns\MapsFinishReason; use Prism\Prism\Providers\OpenRouter\Concerns\ValidatesResponses; -use Prism\Prism\Providers\OpenRouter\Maps\FinishReasonMap; use Prism\Prism\Providers\OpenRouter\Maps\MessageMap; use Prism\Prism\Structured\Request; use Prism\Prism\Structured\Response as StructuredResponse; use Prism\Prism\Structured\ResponseBuilder; use Prism\Prism\Structured\Step; use Prism\Prism\ValueObjects\Messages\AssistantMessage; +use Prism\Prism\ValueObjects\Messages\ToolResultMessage; use Prism\Prism\ValueObjects\Meta; +use Prism\Prism\ValueObjects\ToolResult; use Prism\Prism\ValueObjects\Usage; class Structured { use BuildsRequestOptions; + use CallsTools; use MapsFinishReason; use ValidatesResponses; @@ -40,7 +45,11 @@ public function handle(Request $request): StructuredResponse $this->validateResponse($data); - return $this->createResponse($request, $data); + return match ($this->mapFinishReason($data)) { + FinishReason::ToolCalls => $this->handleToolCalls($data, $request), + FinishReason::Stop, FinishReason::Length => $this->handleStop($data, $request), + default => throw new PrismException('OpenRouter: unknown finish reason'), + }; } /** @@ -86,31 +95,35 @@ protected function validateResponse(array $data): void /** * @param array $data */ - protected function createResponse(Request $request, array $data): StructuredResponse + protected function handleToolCalls(array $data, Request $request): StructuredResponse { - $text = data_get($data, 'choices.0.message.content') ?? ''; + $toolCalls = ToolCallMap::map(data_get($data, 'choices.0.message.tool_calls', [])); - $responseMessage = new AssistantMessage($text); - $request->addMessage($responseMessage); + $toolResults = $this->callTools($request->tools(), $toolCalls); - $step = new Step( - text: $text, - finishReason: FinishReasonMap::map(data_get($data, 'choices.0.finish_reason', '')), - usage: new Usage( - (int) data_get($data, 'usage.prompt_tokens', 0), - (int) data_get($data, 'usage.completion_tokens', 0), - ), - meta: new Meta( - id: data_get($data, 'id', ''), - model: data_get($data, 'model', $request->model()), - ), - messages: $request->messages(), - systemPrompts: $request->systemPrompts(), - additionalContent: [], - raw: $data - ); + $this->addStep($data, $request, $toolResults); + + $request = $request->addMessage(new AssistantMessage( + data_get($data, 'choices.0.message.content') ?? '', + $toolCalls, + [] + )); + $request = $request->addMessage(new ToolResultMessage($toolResults)); + $request->resetToolChoice(); + + if ($this->shouldContinue($request)) { + return $this->handle($request); + } + + return $this->responseBuilder->toResponse(); + } - $this->responseBuilder->addStep($step); + /** + * @param array $data + */ + protected function handleStop(array $data, Request $request): StructuredResponse + { + $this->addStep($data, $request); try { return $this->responseBuilder->toResponse(); @@ -125,4 +138,36 @@ protected function createResponse(Request $request, array $data): StructuredResp throw new PrismStructuredDecodingException($e->getMessage().$context); } } + + protected function shouldContinue(Request $request): bool + { + return $this->responseBuilder->steps->count() < $request->maxSteps(); + } + + /** + * @param array $data + * @param array $toolResults + */ + protected function addStep(array $data, Request $request, array $toolResults = []): void + { + $this->responseBuilder->addStep(new Step( + text: data_get($data, 'choices.0.message.content') ?? '', + finishReason: $this->mapFinishReason($data), + usage: new Usage( + (int) data_get($data, 'usage.prompt_tokens', 0), + (int) data_get($data, 'usage.completion_tokens', 0), + ), + meta: new Meta( + id: data_get($data, 'id', ''), + model: data_get($data, 'model', $request->model()), + ), + messages: $request->messages(), + systemPrompts: $request->systemPrompts(), + additionalContent: [], + toolCalls: ToolCallMap::map(data_get($data, 'choices.0.message.tool_calls', [])), + providerToolCalls: [], + toolResults: $toolResults, + raw: $data, + )); + } } diff --git a/src/Structured/Response.php b/src/Structured/Response.php index 688e8010..b3c0c103 100644 --- a/src/Structured/Response.php +++ b/src/Structured/Response.php @@ -8,6 +8,7 @@ use Illuminate\Support\Collection; use Prism\Prism\Enums\FinishReason; use Prism\Prism\ValueObjects\Meta; +use Prism\Prism\ValueObjects\ProviderToolCall; use Prism\Prism\ValueObjects\ToolCall; use Prism\Prism\ValueObjects\ToolResult; use Prism\Prism\ValueObjects\Usage; @@ -21,6 +22,7 @@ * @param Collection $steps * @param array $structured * @param array $toolCalls + * @param array $providerToolCalls * @param array $toolResults * @param array $additionalContent * @param array|null $raw @@ -33,6 +35,7 @@ public function __construct( public Usage $usage, public Meta $meta, public array $toolCalls = [], + public array $providerToolCalls = [], public array $toolResults = [], public array $additionalContent = [], public ?array $raw = null @@ -52,6 +55,7 @@ public function toArray(): array 'usage' => $this->usage->toArray(), 'meta' => $this->meta->toArray(), 'tool_calls' => array_map(fn (ToolCall $toolCall): array => $toolCall->toArray(), $this->toolCalls), + 'provider_tool_calls' => array_map(fn (ProviderToolCall $providerToolCall): array => $providerToolCall->toArray(), $this->providerToolCalls), 'tool_results' => array_map(fn (ToolResult $toolResult): array => $toolResult->toArray(), $this->toolResults), 'additional_content' => $this->additionalContent, 'raw' => $this->raw, diff --git a/src/Structured/Step.php b/src/Structured/Step.php index 70e14875..7d1a4fb3 100644 --- a/src/Structured/Step.php +++ b/src/Structured/Step.php @@ -12,6 +12,7 @@ use Prism\Prism\ValueObjects\Messages\ToolResultMessage; use Prism\Prism\ValueObjects\Messages\UserMessage; use Prism\Prism\ValueObjects\Meta; +use Prism\Prism\ValueObjects\ProviderToolCall; use Prism\Prism\ValueObjects\ToolCall; use Prism\Prism\ValueObjects\ToolResult; use Prism\Prism\ValueObjects\Usage; @@ -27,6 +28,7 @@ * @param array $additionalContent * @param array $structured * @param array $toolCalls + * @param array $providerToolCalls * @param array $toolResults * @param array|null $raw */ @@ -40,6 +42,7 @@ public function __construct( public array $additionalContent = [], public array $structured = [], public array $toolCalls = [], + public array $providerToolCalls = [], public array $toolResults = [], public ?array $raw = null ) {} @@ -61,6 +64,7 @@ public function toArray(): array 'structured' => $this->structured, 'tool_calls' => array_map(fn (ToolCall $toolCall): array => $toolCall->toArray(), $this->toolCalls), 'tool_results' => array_map(fn (ToolResult $toolResult): array => $toolResult->toArray(), $this->toolResults), + 'provider_tool_calls' => array_map(fn (ProviderToolCall $providerToolCall): array => $providerToolCall->toArray(), $this->providerToolCalls), 'raw' => $this->raw, ]; }