feat: add inference provider specific extra_headers fields to be used at runtime #5217
NickGagan wants to merge 5 commits into llamastack:main
Conversation
✅ Recordings committed successfully. Recordings from the integration tests have been committed to this PR.
```python
    except (AttributeError, ValueError):
        return {}
    if provider_data and (headers := getattr(provider_data, self.provider_data_extra_headers_field, None)):
        return dict(headers)
```
i think this needs a header blocklist before it merges. the openai SDK builds `default_headers` as `{...sdk_defaults, **self.auth_headers, **self._custom_headers}`, so a caller who passes `{"Authorization": "Bearer attacker-token"}` in `openai_extra_headers` silently overwrites the provider's configured API key on every request; the runtime value wins on collision.
PR #5100 had a `BLOCKED_HEADERS` frozenset for exactly this.
the description says "this might be OK since these are explicit per-provider fields", but the fields are still populated from caller-controlled JSON, so a single compromised client token could use this to override auth on every request that hits that provider.
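a minimal sketch of what that blocklist could look like, modeled on the `BLOCKED_HEADERS` idea from #5100 (the names and header set here are illustrative, not the actual #5100 code):

```python
# Hypothetical blocklist applied before runtime headers are merged into
# the client's default_headers. Comparison is case-insensitive because
# HTTP header names are.
BLOCKED_HEADERS = frozenset({"authorization", "x-api-key", "api-key"})


def filter_extra_headers(headers: dict[str, str]) -> dict[str, str]:
    """Drop caller-supplied headers that could override provider auth."""
    return {k: v for k, v in headers.items() if k.lower() not in BLOCKED_HEADERS}


# A caller trying to smuggle an Authorization header loses it,
# while harmless headers pass through untouched:
safe = filter_extra_headers({"Authorization": "Bearer attacker-token", "X-Trace-Id": "abc"})
```

with this in place, the provider's configured API key always wins on collision regardless of what the caller puts in the provider-data JSON.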
```python
        try:
            provider_data = self.get_request_provider_data()
        except (AttributeError, ValueError):
            return {}
```
this swallows ValueError silently: `get_request_provider_data()` raises ValueError when provider data is malformed, so the caller gets no headers and no error. the API-key path in the same file (`_get_api_key_from_config_or_provider_data()`) calls the same method with no try/except and lets errors surface. worth either removing the guard to match that pattern, or at minimum adding a log warning so operators can diagnose missing headers.
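a sketch of the log-warning variant, keeping the AttributeError fallback but making malformed provider data visible (the adapter argument and attribute names are stand-ins for the real mixin, not the actual code):

```python
import logging

log = logging.getLogger(__name__)


def get_provider_data_extra_headers(adapter) -> dict[str, str]:
    """Hypothetical sketch: return runtime extra headers, logging instead
    of silently swallowing malformed provider data."""
    try:
        provider_data = adapter.get_request_provider_data()
    except ValueError as e:
        # malformed provider data: still degrade to no headers, but say so
        log.warning("malformed provider data, skipping extra headers: %s", e)
        return {}
    except AttributeError:
        # no provider-data context at all; nothing to forward
        return {}
    headers = getattr(provider_data, adapter.provider_data_extra_headers_field, None)
    return dict(headers) if headers else {}
```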
```diff
 endpoint = self.get_base_url().replace("/v1", "") + "/rerank"  # TODO: find a better solution
-async with session.post(endpoint, headers={}, json=payload) as response:
+async with session.post(
+    endpoint, headers=self._get_provider_data_extra_headers(), json=payload
```
the rerank aiohttp path forwards extra headers but still doesn't include the `vllm_api_token`, so a vLLM endpoint that requires auth gets the custom headers but no Authorization. the inference paths get the token via `self.client` (`AsyncOpenAI`), but aiohttp doesn't go through that path. worth either adding the token here or leaving a comment explaining the gap so it doesn't surprise the next person who adds auth to the rerank endpoint.
```python
    assert inference_adapter.client.api_key == api_key


@pytest.mark.parametrize(
```
the existing `test_openai_provider_data_used` tests sequential requests to prove API keys don't leak. worth adding a similar test for extra headers using `asyncio.gather`: two concurrent tasks with different `vllm_extra_headers` values, asserting neither sees the other's headers. the ContextVar scoping should handle this correctly, but since the whole file exists to guard against this class of bug it's worth having the test.
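the shape of that test, sketched with a bare ContextVar standing in for the provider-data machinery (names here are illustrative, not the real fixtures):

```python
import asyncio
import contextvars

# Stand-in for the request-scoped provider-data ContextVar.
_provider_data = contextvars.ContextVar("provider_data")


def current_extra_headers() -> dict[str, str]:
    return _provider_data.get({})


async def request_with_headers(headers: dict[str, str]) -> dict[str, str]:
    """Simulate one request: set headers, yield to the event loop so the
    concurrent task interleaves, then read them back."""
    _provider_data.set(headers)
    await asyncio.sleep(0)
    return current_extra_headers()


async def main():
    # each asyncio task runs in its own copy of the context, so the
    # sets above must not be visible across tasks
    a, b = await asyncio.gather(
        request_with_headers({"X-Tenant": "alice"}),
        request_with_headers({"X-Tenant": "bob"}),
    )
    assert a == {"X-Tenant": "alice"}
    assert b == {"X-Tenant": "bob"}
    return a, b


a, b = asyncio.run(main())
```

if ContextVar scoping ever regressed (e.g. headers cached on a shared client), one task would observe the other's tenant header and the assertions would fail.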
cdoern left a comment
one comment similar to what Sumanth said, otherwise lgtm.
```python
        try:
            provider_data = self.get_request_provider_data()
        except (AttributeError, ValueError):
            return {}
```
is there a case here where we want to raise this err?
What does this PR do?
Related issue: #5077
Alternative to PR #5100 (working with @Lucifergene on this)
Based off of feedback from @mattf: #5100 (comment)
For each inference provider, similar to API keys, there can be an additional `provider_data_extra_headers_field` (e.g. `vllm_extra_headers`) that allows you to forward headers at runtime. This PR only adds it for OpenAI and vLLM, but happy to add it for the rest if this is the approach we want to take.
For the security concern, this might be OK since each of these is an explicit per-provider field: #4607.
Test Plan
Automated
Unit tests in `tests/unit/providers/inference/test_inference_client_caching.py` verify that `vllm_extra_headers` and `openai_extra_headers` from `X-LlamaStack-Provider-Data` are forwarded as `default_headers` on the outbound `AsyncOpenAI` client.
Manual (end-to-end header forwarding)
1. Start a mock OpenAI server that logs incoming headers
2. Add to your run config
3. Send requests with and without extra headers
4. Confirm headers in mock server output
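the mock server from step 1 can be sketched with just the standard library (this is a hypothetical stand-in; the PR does not include the server's source):

```python
# Minimal mock endpoint that prints every incoming request's headers to
# stdout, so the presence or absence of X-Test-Header is easy to spot.
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


class HeaderLogger(BaseHTTPRequestHandler):
    def do_POST(self):
        print(dict(self.headers))  # forwarded extra headers show up here
        body = json.dumps({"ok": True}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):
        pass  # silence the default access log; we only want headers


if __name__ == "__main__":
    HTTPServer(("127.0.0.1", 8080), HeaderLogger).serve_forever()
```

point the run config's provider `url` at `http://127.0.0.1:8080` (port is illustrative) and the forwarded headers appear in the server's terminal.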
The mock server prints received headers to stdout for each request. Check its terminal output:
With `vllm_extra_headers` (first request):
Without `vllm_extra_headers` (second request):
The presence of `X-Test-Header: confirmed` in the first request and its absence in the second confirms end-to-end header forwarding is working correctly.