Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions go/adk/pkg/a2a/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"fmt"
"maps"
"os"
"strings"

a2atype "github.com/a2aproject/a2a-go/a2a"
"github.com/a2aproject/a2a-go/a2asrv"
"github.com/a2aproject/a2a-go/a2asrv/eventqueue"
"github.com/go-logr/logr"
"github.com/kagent-dev/kagent/go/adk/pkg/models"
"github.com/kagent-dev/kagent/go/adk/pkg/session"
"github.com/kagent-dev/kagent/go/adk/pkg/skills"
"github.com/kagent-dev/kagent/go/adk/pkg/telemetry"
Expand Down Expand Up @@ -114,6 +116,19 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont
}
sessionID := reqCtx.ContextID

// Extract Bearer token from incoming request for API key passthrough
if callCtx, ok := a2asrv.CallContextFrom(ctx); ok {
if meta := callCtx.RequestMeta(); meta != nil {
if vals, ok := meta.Get("authorization"); ok && len(vals) > 0 && vals[0] != "" {
auth := strings.TrimSpace(vals[0])
parts := strings.Fields(auth)
if len(parts) >= 2 && strings.EqualFold(parts[0], "Bearer") {
ctx = context.WithValue(ctx, models.BearerTokenKey, parts[1])
}
}
}
}

e.logger.Info("Execute",
"taskID", reqCtx.TaskID,
"contextID", reqCtx.ContextID,
Expand Down
82 changes: 50 additions & 32 deletions go/adk/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,25 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig

// Build BeforeToolCallbacks. Approval gating runs first.
beforeToolCallbacks := []llmagent.BeforeToolCallback{}
// Strip synthetic HITL tool messages from the model request to avoid unnecessary token usage.
beforeModelCallbacks := []llmagent.BeforeModelCallback{}
if len(approvalSet) > 0 {
log.Info("Wiring approval callback", "toolCount", len(approvalSet))
beforeToolCallbacks = append(beforeToolCallbacks, MakeApprovalCallback(approvalSet))
beforeModelCallbacks = append(beforeModelCallbacks, MakeStripConfirmationPartsCallback())
}
beforeToolCallbacks = append(beforeToolCallbacks, makeBeforeToolCallback(log))

llmAgentConfig := llmagent.Config{
Name: agentName,
Description: agentConfig.Description,
Instruction: agentConfig.Instruction,
Model: llmModel,
IncludeContents: llmagent.IncludeContentsDefault,
Tools: localTools,
Toolsets: toolsets,
BeforeToolCallbacks: beforeToolCallbacks,
Name: agentName,
Description: agentConfig.Description,
Instruction: agentConfig.Instruction,
Model: llmModel,
IncludeContents: llmagent.IncludeContentsDefault,
Tools: localTools,
Toolsets: toolsets,
BeforeToolCallbacks: beforeToolCallbacks,
BeforeModelCallbacks: beforeModelCallbacks,
AfterToolCallbacks: []llmagent.AfterToolCallback{
makeAfterToolCallback(log),
},
Expand Down Expand Up @@ -179,26 +183,24 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
switch m := m.(type) {
case *adk.OpenAI:
cfg := &models.OpenAIConfig{
TransportConfig: transportConfigFromBase(m.BaseModel, m.Timeout),
Model: m.Model,
BaseUrl: m.BaseUrl,
Headers: extractHeaders(m.Headers),
FrequencyPenalty: m.FrequencyPenalty,
MaxTokens: m.MaxTokens,
N: m.N,
PresencePenalty: m.PresencePenalty,
ReasoningEffort: m.ReasoningEffort,
Seed: m.Seed,
Temperature: m.Temperature,
Timeout: m.Timeout,
TopP: m.TopP,
}
return models.NewOpenAIModelWithLogger(cfg, log)

case *adk.AzureOpenAI:
cfg := &models.AzureOpenAIConfig{
Model: m.Model,
Headers: extractHeaders(m.Headers),
Timeout: nil,
TransportConfig: transportConfigFromBase(m.BaseModel, nil),
Model: m.Model,
}
return models.NewAzureOpenAIModelWithLogger(cfg, log)

Expand Down Expand Up @@ -241,14 +243,13 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
modelName = DefaultAnthropicModel
}
cfg := &models.AnthropicConfig{
Model: modelName,
BaseUrl: m.BaseUrl,
Headers: extractHeaders(m.Headers),
MaxTokens: m.MaxTokens,
Temperature: m.Temperature,
TopP: m.TopP,
TopK: m.TopK,
Timeout: m.Timeout,
TransportConfig: transportConfigFromBase(m.BaseModel, m.Timeout),
Model: modelName,
BaseUrl: m.BaseUrl,
MaxTokens: m.MaxTokens,
Temperature: m.Temperature,
TopP: m.TopP,
TopK: m.TopK,
}
return models.NewAnthropicModelWithLogger(cfg, log)

Expand All @@ -257,15 +258,18 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
if baseURL == "" {
baseURL = "http://localhost:11434"
}
baseURL = strings.TrimSuffix(baseURL, "/")
if !strings.HasSuffix(baseURL, "/v1") {
baseURL += "/v1"
}
modelName := m.Model
if modelName == "" {
modelName = DefaultOllamaModel
}
return models.NewOpenAICompatibleModelWithLogger(baseURL, modelName, extractHeaders(m.Headers), "", log)
// Create OllamaConfig with native SDK support for Ollama-specific options
cfg := &models.OllamaConfig{
TransportConfig: transportConfigFromBase(m.BaseModel, nil),
Model: modelName,
Host: baseURL,
Options: m.Options,
}
return models.NewOllamaModelWithLogger(cfg, log)

case *adk.Bedrock:
region := m.Region
Expand All @@ -279,11 +283,13 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
if modelName == "" {
return nil, fmt.Errorf("bedrock requires a model name (e.g. anthropic.claude-3-sonnet-20240229-v1:0)")
}
cfg := &models.AnthropicConfig{
Model: modelName,
Headers: extractHeaders(m.Headers),
// Use Bedrock Converse API for ALL models (including Anthropic)
cfg := &models.BedrockConfig{
TransportConfig: transportConfigFromBase(m.BaseModel, nil),
Model: modelName,
Region: region,
}
return models.NewAnthropicBedrockModelWithLogger(ctx, cfg, region, log)
return models.NewBedrockModelWithLogger(ctx, cfg, log)

case *adk.GeminiAnthropic:
// GeminiAnthropic = Claude models accessed through Google Cloud Vertex AI.
Expand All @@ -301,8 +307,8 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
modelName = DefaultAnthropicModel
}
cfg := &models.AnthropicConfig{
Model: modelName,
Headers: extractHeaders(m.Headers),
TransportConfig: transportConfigFromBase(m.BaseModel, nil),
Model: modelName,
}
return models.NewAnthropicVertexAIModelWithLogger(ctx, cfg, region, project, log)

Expand All @@ -311,6 +317,18 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
}
}

// transportConfigFromBase maps the shared BaseModel transport/TLS settings
// (plus an optional per-model timeout) onto a models.TransportConfig.
func transportConfigFromBase(b adk.BaseModel, timeout *int) models.TransportConfig {
	cfg := models.TransportConfig{
		Timeout:               timeout,
		Headers:               extractHeaders(b.Headers),
		APIKeyPassthrough:     b.APIKeyPassthrough,
		TLSCACertPath:         b.TLSCACertPath,
		TLSDisableSystemCAs:   b.TLSDisableSystemCAs,
		TLSInsecureSkipVerify: b.TLSInsecureSkipVerify,
	}
	return cfg
}

// extractHeaders returns an empty map if nil, the original map otherwise.
func extractHeaders(headers map[string]string) map[string]string {
if headers == nil {
Expand Down
42 changes: 42 additions & 0 deletions go/adk/pkg/agent/approval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,52 @@ package agent
import (
"fmt"

"google.golang.org/adk/agent"
"google.golang.org/adk/agent/llmagent"
adkmodel "google.golang.org/adk/model"
"google.golang.org/adk/tool"
"google.golang.org/genai"
)

// MakeStripConfirmationPartsCallback returns a BeforeModelCallback that drops
// synthetic adk_request_confirmation FunctionCall and FunctionResponse parts
// from the outgoing LLM request. These HITL bookkeeping parts were never
// produced by the LLM and sending them back only wastes tokens; the session
// still stores them so ADK's resume machinery can find them.
func MakeStripConfirmationPartsCallback() llmagent.BeforeModelCallback {
	// isConfirmation reports whether a part is a synthetic HITL confirmation
	// call or response that should be hidden from the model.
	isConfirmation := func(p *genai.Part) bool {
		if p.FunctionCall != nil && p.FunctionCall.Name == "adk_request_confirmation" {
			return true
		}
		return p.FunctionResponse != nil && p.FunctionResponse.Name == "adk_request_confirmation"
	}
	return func(_ agent.CallbackContext, req *adkmodel.LLMRequest) (*adkmodel.LLMResponse, error) {
		kept := make([]*genai.Content, 0, len(req.Contents))
		for _, content := range req.Contents {
			if content == nil {
				continue
			}
			parts := make([]*genai.Part, 0, len(content.Parts))
			for _, part := range content.Parts {
				if part == nil || isConfirmation(part) {
					continue
				}
				parts = append(parts, part)
			}
			// Drop contents left empty after filtering, mirroring nil handling.
			if len(parts) == 0 {
				continue
			}
			kept = append(kept, &genai.Content{
				Role:  content.Role,
				Parts: parts,
			})
		}
		req.Contents = kept
		// Returning a nil response lets the request proceed to the model.
		return nil, nil
	}
}

// MakeApprovalCallback creates a BeforeToolCallback that gates execution of
// tools in the approval set behind request_confirmation / ToolConfirmation.
// Port of kagent-adk/src/kagent/adk/_approval.py:make_approval_callback().
Expand Down
4 changes: 3 additions & 1 deletion go/adk/pkg/agent/createllm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ func TestAgent_OpenAI_WithParams(t *testing.T) {
}

func TestAgent_Ollama(t *testing.T) {
// mockllm does not support the native Ollama /api/chat endpoint,
// so we test with an OpenAI-compatible model pointing at the mock.
baseURL := startMock(t, "testdata/mock_openai.json")
t.Setenv("OLLAMA_API_BASE", baseURL)
t.Setenv("OPENAI_API_KEY", "ollama") // placeholder, Ollama ignores it

cfg := loadConfig(t, "testdata/config_ollama.json", baseURL)
text := runAgent(t, cfg, "What is 2+2?")
Expand Down
5 changes: 3 additions & 2 deletions go/adk/pkg/agent/testdata/config_ollama.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"model": {
"type": "ollama",
"model": "llama3.2"
"type": "openai",
"model": "llama3.2",
"base_url": "{{BASE_URL}}/v1"
},
"description": "test",
"instruction": "Answer concisely."
Expand Down
Loading
Loading