
feat: add inference provider specific extra_headers fields to be used at runtime #5217

Open
NickGagan wants to merge 5 commits into llamastack:main from NickGagan:vllm_api_key_ng

Conversation


@NickGagan (Contributor) commented Mar 19, 2026

What does this PR do?

Related issue: #5077
Alternative to PR: #5100 (Working with @Lucifergene on this)
Based off of feedback from @mattf: #5100 (comment)

For each inference provider, similar to API keys, there can be an additional extra-headers field in the provider data (named by provider_data_extra_headers_field, e.g. vllm_extra_headers) that allows callers to forward headers at runtime.

This PR only adds it for OpenAI and vLLM, but happy to add it for the rest if this is the approach we want to take.

For the security concern, this might be OK since each of these is an explicit per-provider field: #4607.
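As a rough sketch of the runtime flow (illustrative only; the field name vllm_extra_headers and the X-LlamaStack-Provider-Data header come from this PR, but the helper below is an assumption, not the PR's implementation):

```python
# Sketch only: how a per-provider extra-headers field could be pulled out of
# the X-LlamaStack-Provider-Data request header. PROVIDER_DATA_EXTRA_HEADERS_FIELD
# stands in for the attribute each adapter declares; the parsing code is illustrative.
import json

PROVIDER_DATA_EXTRA_HEADERS_FIELD = "vllm_extra_headers"

def extra_headers_from_provider_data(raw_header: str) -> dict[str, str]:
    """Parse the provider-data JSON and return the headers to forward, if any."""
    try:
        provider_data = json.loads(raw_header)
    except ValueError:
        # Malformed provider data: forward nothing.
        return {}
    headers = provider_data.get(PROVIDER_DATA_EXTRA_HEADERS_FIELD)
    return dict(headers) if isinstance(headers, dict) else {}

raw = '{"vllm_extra_headers": {"X-Test-Header": "confirmed"}}'
print(extra_headers_from_provider_data(raw))  # {'X-Test-Header': 'confirmed'}
```

The resulting dict would then be supplied as default_headers when the provider's outbound AsyncOpenAI client is constructed.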

Test Plan

Automated

Unit tests in tests/unit/providers/inference/test_inference_client_caching.py verify that
vllm_extra_headers and openai_extra_headers from X-LlamaStack-Provider-Data are forwarded
as default_headers on the outbound AsyncOpenAI client.

Manual (end-to-end header forwarding)

1. Start a mock OpenAI server that logs incoming headers

```python
# /tmp/mock_openai_server.py
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

class Handler(BaseHTTPRequestHandler):
    def do_GET(self):
        self._respond({"object": "list", "data": [{"id": "test-model", "object": "model"}]})

    def do_POST(self):
        self.rfile.read(int(self.headers.get("Content-Length", 0)))
        print("=== RECEIVED HEADERS ===", flush=True)
        for k, v in self.headers.items():
            print(f"  {k}: {v}", flush=True)
        self._respond({
            "id": "x", "object": "chat.completion", "created": 0, "model": "test-model",
            "choices": [{"index": 0, "message": {"role": "assistant", "content": "hi"}, "finish_reason": "stop"}],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
        })

    def _respond(self, data):
        b = json.dumps(data).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(b)))
        self.end_headers()
        self.wfile.write(b)

    def log_message(self, *a):
        pass

HTTPServer(("localhost", 9001), Handler).serve_forever()
```

```shell
python3 /tmp/mock_openai_server.py
```

2. Add to your run config

```yaml
version: "2"
image_name: extra-headers-test
apis:
  - inference
  - models
providers:
  inference:
    - provider_id: mock-vllm
      provider_type: remote::vllm
      config:
        base_url: http://localhost:9001/v1
        refresh_models: false

metadata_store:
  type: sqlite
  db_path: /tmp/extra_headers_test_registry.db

storage:
  backends:
    kv_default:
      type: kv_sqlite
      db_path: /tmp/extra_headers_test_kvstore.db
    sql_default:
      type: sql_sqlite
      db_path: /tmp/extra_headers_test_sqlstore.db

registered_resources:
  models:
    - provider_id: mock-vllm
      model_id: test-model
      provider_resource_id: test-model
      model_type: llm
      metadata:
        display_name: Test Model (Mock)
  shields: []
  vector_stores: []
  datasets: []
  scoring_fns: []
  benchmarks: []
  tool_groups: []

server:
  port: 8321
```

3. Send requests with and without extra headers

```shell
# With header — X-Test-Header should appear in server output
curl -X POST http://localhost:8321/v1/chat/completions \
  -H 'X-LlamaStack-Provider-Data: {"vllm_extra_headers": {"X-Test-Header": "confirmed"}}' \
  -H "Content-Type: application/json" \
  -d '{"model": "mock-vllm/test-model", "messages": [{"role": "user", "content": "hi"}]}'

# Without header — X-Test-Header should be absent
curl -X POST http://localhost:8321/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "mock-vllm/test-model", "messages": [{"role": "user", "content": "hi"}]}'
```

4. Confirm headers in mock server output

The mock server prints received headers to stdout for each request. Check its terminal output:

With vllm_extra_headers (first request):

```
=== RECEIVED HEADERS ===
  ...
  X-Test-Header: confirmed
  ...
```

Without vllm_extra_headers (second request):

```
=== RECEIVED HEADERS ===
  ...
  (X-Test-Header should not appear)
  ...
```

The presence of X-Test-Header: confirmed in the first request and its absence in the second confirms end-to-end header forwarding is working correctly.


github-actions bot commented Mar 19, 2026

Recordings committed successfully

Recordings from the integration tests have been committed to this PR.

View commit workflow

@NickGagan (Contributor, Author):

FYI: @franciscojavierarceo

NickGagan changed the title from "feat: inference provider specific extra headers to be used at runtime" to "feat: add inference provider specific extra_headers fields to be used at runtime" on Mar 19, 2026
```python
    except (AttributeError, ValueError):
        return {}
    if provider_data and (headers := getattr(provider_data, self.provider_data_extra_headers_field, None)):
        return dict(headers)
```
@skamenan7 (Contributor) commented Mar 20, 2026:
i think this needs a header blocklist before it merges. the openai SDK builds default_headers as {**sdk_defaults, **self.auth_headers, **self._custom_headers}, so a caller who passes {"Authorization": "Bearer attacker-token"} in openai_extra_headers silently overwrites the provider's configured API key on every request — the runtime value wins on collision.

PR #5100 had a BLOCKED_HEADERS frozenset for exactly this.

the description says "this might be OK since these are explicit per-provider fields", but the fields are still populated from caller-controlled JSON, so I would imagine a single compromised client token can use this to override auth on every request that hits that provider.
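A minimal sketch of the blocklist being suggested (the set's contents are an assumption here, not PR #5100's exact BLOCKED_HEADERS):

```python
# Sketch of a header blocklist as suggested above; the membership is illustrative.
# Keys are compared case-insensitively because HTTP header names are.
BLOCKED_HEADERS = frozenset({"authorization", "api-key", "x-api-key"})

def filter_extra_headers(headers: dict[str, str]) -> dict[str, str]:
    """Drop caller-supplied headers that would collide with provider auth."""
    return {k: v for k, v in headers.items() if k.lower() not in BLOCKED_HEADERS}

print(filter_extra_headers({"Authorization": "Bearer attacker-token", "X-Trace-Id": "abc"}))
# {'X-Trace-Id': 'abc'}
```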

```python
try:
    provider_data = self.get_request_provider_data()
except (AttributeError, ValueError):
    return {}
```
Contributor:

this swallows ValueError silently: get_request_provider_data() raises ValueError when provider data is malformed, so the caller gets no headers and no error. the API-key path in the same file (_get_api_key_from_config_or_provider_data()) calls the same method with no try/except and lets errors surface. worth either removing the guard to match that pattern, or at minimum adding a log warning so operators can diagnose missing headers.

```diff
 endpoint = self.get_base_url().replace("/v1", "") + "/rerank"  # TODO: find a better solution
-async with session.post(endpoint, headers={}, json=payload) as response:
+async with session.post(
+    endpoint, headers=self._get_provider_data_extra_headers(), json=payload
```
Contributor:

the rerank aiohttp path forwards extra headers but still doesn't include the vllm_api_token, so a vLLM endpoint that requires auth gets the custom headers but no Authorization. the inference paths get the token via self.client (AsyncOpenAI), but aiohttp doesn't go through that path. worth either adding the token here or leaving a comment explaining the gap so it doesn't surprise the next person who adds auth to the rerank endpoint.

```python
assert inference_adapter.client.api_key == api_key


@pytest.mark.parametrize(
```
Contributor:

the existing test_openai_provider_data_used tests sequential requests to prove API keys don't leak. worth adding a similar test for extra headers using asyncio.gather: two concurrent tasks with different vllm_extra_headers values, asserting neither sees the other's headers. the ContextVar scoping should handle this correctly, but since the whole file exists to guard against this class of bug it's worth having the test.
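The concurrency test could be sketched as below. This is a self-contained illustration of the ContextVar behavior being relied on, not the repo's test: PROVIDER_DATA stands in for llama-stack's request-scoped provider data.

```python
# Sketch: two concurrent tasks with different vllm_extra_headers; each task's
# ContextVar write is isolated because asyncio.gather runs them as separate Tasks,
# and each Task copies the current context at creation.
import asyncio
from contextvars import ContextVar

PROVIDER_DATA: ContextVar[dict] = ContextVar("provider_data")

async def resolve_headers(provider_data: dict) -> dict:
    PROVIDER_DATA.set(provider_data)
    await asyncio.sleep(0.01)  # yield so the two tasks interleave
    return PROVIDER_DATA.get().get("vllm_extra_headers", {})

async def main():
    a, b = await asyncio.gather(
        resolve_headers({"vllm_extra_headers": {"X-Tenant": "alice"}}),
        resolve_headers({"vllm_extra_headers": {"X-Tenant": "bob"}}),
    )
    # Neither task sees the other's headers.
    assert a == {"X-Tenant": "alice"}
    assert b == {"X-Tenant": "bob"}
    return a, b

print(asyncio.run(main()))
```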

@cdoern (Collaborator) left a comment:

one comment similar to what Sumanth said, otherwise lgtm.

```python
try:
    provider_data = self.get_request_provider_data()
except (AttributeError, ValueError):
    return {}
```
Collaborator:

is there a case here where we want to raise this error?


Labels: CLA Signed (managed by the Meta Open Source bot)

3 participants