Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 68 additions & 23 deletions src/Providers/OpenRouter/Handlers/Structured.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,29 @@

use Illuminate\Http\Client\PendingRequest;
use Illuminate\Http\Client\Response;
use Prism\Prism\Concerns\CallsTools;
use Prism\Prism\Enums\FinishReason;
use Prism\Prism\Exceptions\PrismException;
use Prism\Prism\Exceptions\PrismStructuredDecodingException;
use Prism\Prism\Providers\DeepSeek\Maps\ToolCallMap;
use Prism\Prism\Providers\OpenRouter\Concerns\BuildsRequestOptions;
use Prism\Prism\Providers\OpenRouter\Concerns\MapsFinishReason;
use Prism\Prism\Providers\OpenRouter\Concerns\ValidatesResponses;
use Prism\Prism\Providers\OpenRouter\Maps\FinishReasonMap;
use Prism\Prism\Providers\OpenRouter\Maps\MessageMap;
use Prism\Prism\Structured\Request;
use Prism\Prism\Structured\Response as StructuredResponse;
use Prism\Prism\Structured\ResponseBuilder;
use Prism\Prism\Structured\Step;
use Prism\Prism\ValueObjects\Messages\AssistantMessage;
use Prism\Prism\ValueObjects\Messages\ToolResultMessage;
use Prism\Prism\ValueObjects\Meta;
use Prism\Prism\ValueObjects\ToolResult;
use Prism\Prism\ValueObjects\Usage;

class Structured
{
use BuildsRequestOptions;
use CallsTools;
use MapsFinishReason;
use ValidatesResponses;

Expand All @@ -40,7 +45,11 @@ public function handle(Request $request): StructuredResponse

$this->validateResponse($data);

return $this->createResponse($request, $data);
return match ($this->mapFinishReason($data)) {
FinishReason::ToolCalls => $this->handleToolCalls($data, $request),
FinishReason::Stop, FinishReason::Length => $this->handleStop($data, $request),
default => throw new PrismException('OpenRouter: unknown finish reason'),
};
}

/**
Expand Down Expand Up @@ -86,31 +95,35 @@ protected function validateResponse(array $data): void
/**
* @param array<string, mixed> $data
*/
protected function createResponse(Request $request, array $data): StructuredResponse
protected function handleToolCalls(array $data, Request $request): StructuredResponse
{
$text = data_get($data, 'choices.0.message.content') ?? '';
$toolCalls = ToolCallMap::map(data_get($data, 'choices.0.message.tool_calls', []));

$responseMessage = new AssistantMessage($text);
$request->addMessage($responseMessage);
$toolResults = $this->callTools($request->tools(), $toolCalls);

$step = new Step(
text: $text,
finishReason: FinishReasonMap::map(data_get($data, 'choices.0.finish_reason', '')),
usage: new Usage(
(int) data_get($data, 'usage.prompt_tokens', 0),
(int) data_get($data, 'usage.completion_tokens', 0),
),
meta: new Meta(
id: data_get($data, 'id', ''),
model: data_get($data, 'model', $request->model()),
),
messages: $request->messages(),
systemPrompts: $request->systemPrompts(),
additionalContent: [],
raw: $data
);
$this->addStep($data, $request, $toolResults);

$request = $request->addMessage(new AssistantMessage(
data_get($data, 'choices.0.message.content') ?? '',
$toolCalls,
[]
));
$request = $request->addMessage(new ToolResultMessage($toolResults));
$request->resetToolChoice();

if ($this->shouldContinue($request)) {
return $this->handle($request);
}

return $this->responseBuilder->toResponse();
}

$this->responseBuilder->addStep($step);
/**
* @param array<string, mixed> $data
*/
protected function handleStop(array $data, Request $request): StructuredResponse
{
$this->addStep($data, $request);

try {
return $this->responseBuilder->toResponse();
Expand All @@ -125,4 +138,36 @@ protected function createResponse(Request $request, array $data): StructuredResp
throw new PrismStructuredDecodingException($e->getMessage().$context);
}
}

protected function shouldContinue(Request $request): bool
{
return $this->responseBuilder->steps->count() < $request->maxSteps();
}

/**
* @param array<string, mixed> $data
* @param array<int, ToolResult> $toolResults
*/
protected function addStep(array $data, Request $request, array $toolResults = []): void
{
$this->responseBuilder->addStep(new Step(
text: data_get($data, 'choices.0.message.content') ?? '',
finishReason: $this->mapFinishReason($data),
usage: new Usage(
(int) data_get($data, 'usage.prompt_tokens', 0),
(int) data_get($data, 'usage.completion_tokens', 0),
),
meta: new Meta(
id: data_get($data, 'id', ''),
model: data_get($data, 'model', $request->model()),
),
messages: $request->messages(),
systemPrompts: $request->systemPrompts(),
additionalContent: [],
toolCalls: ToolCallMap::map(data_get($data, 'choices.0.message.tool_calls', [])),
providerToolCalls: [],
toolResults: $toolResults,
raw: $data,
));
}
}
4 changes: 4 additions & 0 deletions src/Structured/Response.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use Illuminate\Support\Collection;
use Prism\Prism\Enums\FinishReason;
use Prism\Prism\ValueObjects\Meta;
use Prism\Prism\ValueObjects\ProviderToolCall;
use Prism\Prism\ValueObjects\ToolCall;
use Prism\Prism\ValueObjects\ToolResult;
use Prism\Prism\ValueObjects\Usage;
Expand All @@ -21,6 +22,7 @@
* @param Collection<int, Step> $steps
* @param array<mixed> $structured
* @param array<int, ToolCall> $toolCalls
* @param array<int, ProviderToolCall> $providerToolCalls
* @param array<int, ToolResult> $toolResults
* @param array<string,mixed> $additionalContent
* @param array<string,mixed>|null $raw
Expand All @@ -33,6 +35,7 @@ public function __construct(
public Usage $usage,
public Meta $meta,
public array $toolCalls = [],
public array $providerToolCalls = [],
public array $toolResults = [],
public array $additionalContent = [],
public ?array $raw = null
Expand All @@ -52,6 +55,7 @@ public function toArray(): array
'usage' => $this->usage->toArray(),
'meta' => $this->meta->toArray(),
'tool_calls' => array_map(fn (ToolCall $toolCall): array => $toolCall->toArray(), $this->toolCalls),
'provider_tool_calls' => array_map(fn (ProviderToolCall $providerToolCall): array => $providerToolCall->toArray(), $this->providerToolCalls),
'tool_results' => array_map(fn (ToolResult $toolResult): array => $toolResult->toArray(), $this->toolResults),
'additional_content' => $this->additionalContent,
'raw' => $this->raw,
Expand Down
4 changes: 4 additions & 0 deletions src/Structured/Step.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
use Prism\Prism\ValueObjects\Messages\ToolResultMessage;
use Prism\Prism\ValueObjects\Messages\UserMessage;
use Prism\Prism\ValueObjects\Meta;
use Prism\Prism\ValueObjects\ProviderToolCall;
use Prism\Prism\ValueObjects\ToolCall;
use Prism\Prism\ValueObjects\ToolResult;
use Prism\Prism\ValueObjects\Usage;
Expand All @@ -27,6 +28,7 @@
* @param array<string,mixed> $additionalContent
* @param array<string,mixed> $structured
* @param array<int, ToolCall> $toolCalls
* @param array<int, ProviderToolCall> $providerToolCalls
* @param array<int, ToolResult> $toolResults
* @param array<string,mixed>|null $raw
*/
Expand All @@ -40,6 +42,7 @@ public function __construct(
public array $additionalContent = [],
public array $structured = [],
public array $toolCalls = [],
public array $providerToolCalls = [],
public array $toolResults = [],
public ?array $raw = null
) {}
Expand All @@ -61,6 +64,7 @@ public function toArray(): array
'structured' => $this->structured,
'tool_calls' => array_map(fn (ToolCall $toolCall): array => $toolCall->toArray(), $this->toolCalls),
'tool_results' => array_map(fn (ToolResult $toolResult): array => $toolResult->toArray(), $this->toolResults),
'provider_tool_calls' => array_map(fn (ProviderToolCall $providerToolCall): array => $providerToolCall->toArray(), $this->providerToolCalls),
'raw' => $this->raw,
];
}
Expand Down