Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions clarifai/runners/models/vllm_openai_class.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,66 @@
import re
import threading
import time
from typing import Iterator

import httpx
from clarifai_protocol import get_item_id, register_item_abort_callback

from clarifai.runners.models.openai_class import OpenAIModelClass
from clarifai.utils.logging import logger


class VLLMMetricsPoller:
    """Polls vLLM /metrics in a background daemon thread; caches kv_cache_usage
    and the waiting-request count for cheap, lock-protected reads.

    Start in load_model() to enable admission control:

        self._metrics_poller = VLLMMetricsPoller(f"http://{host}:{port}")
    """

    # Reject new requests once KV-cache usage exceeds this fraction.
    KV_CACHE_REJECT_THRESHOLD = 0.95
    # Reject when more than this many requests are already queued in the engine.
    MAX_WAITING_REQUESTS = 10

    # Suffix matching a Prometheus sample value after the metric name.
    # Labels are optional, and the value may use exponent notation
    # (e.g. "1.2e-05") — the previous pattern r'\{[^}]*\}\s+([\d.]+)'
    # truncated such values (1.2e-05 -> 1.2), which could falsely trip
    # the KV-cache threshold.
    _VALUE = r'(?:\{[^}]*\})?\s+([0-9.]+(?:[eE][+-]?[0-9]+)?)'

    def __init__(self, base_url: str, poll_interval: float = 0.5):
        """Create the poller and immediately start its daemon thread.

        Args:
            base_url: Base URL of the vLLM server, e.g. "http://127.0.0.1:8000".
            poll_interval: Seconds between successive /metrics polls.
        """
        self.base_url = base_url
        self.poll_interval = poll_interval
        self._kv_cache = 0.0
        self._waiting = 0
        self._lock = threading.Lock()
        # Initialized to "now" so the poller is not considered stale before
        # the first poll completes (fail-open on startup).
        self._last_success = time.time()
        # Allows stop() to terminate the loop; previously the thread ran
        # forever with no shutdown path.
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._poll_loop, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the background thread to exit; it wakes within poll_interval."""
        self._stop_event.set()

    def _poll_loop(self):
        """Poll /metrics until stop() is called, caching parsed values."""
        while not self._stop_event.is_set():
            try:
                resp = httpx.get(f"{self.base_url}/metrics", timeout=1.0)
                if resp.status_code == 200:
                    text = resp.text
                    waiting = int(
                        self._parse(text, r'vllm:num_requests_waiting' + self._VALUE)
                    )
                    kv_cache = self._parse(text, r'vllm:kv_cache_usage_perc' + self._VALUE)
                    with self._lock:
                        self._waiting = waiting
                        self._kv_cache = kv_cache
                        self._last_success = time.time()
            except Exception as e:
                logger.warning(f"[VLLMMetricsPoller] Poll failed: {e}")
            # Event.wait doubles as an interruptible sleep, so stop() takes
            # effect without waiting out a full interval.
            self._stop_event.wait(self.poll_interval)

    def _parse(self, text: str, pattern: str) -> float:
        """Return the first captured group of *pattern* as float; 0.0 when the
        metric is absent or the capture is unparseable (fail-open default)."""
        m = re.search(pattern, text)
        if not m:
            return 0.0
        try:
            return float(m.group(1))
        except ValueError:
            return 0.0

    def snapshot(self):
        """Return (kv_cache, waiting) read atomically under the lock."""
        with self._lock:
            return self._kv_cache, self._waiting

    @property
    def is_stale(self) -> bool:
        """True when no successful poll has completed in the last 5 seconds."""
        with self._lock:
            return time.time() - self._last_success > 5.0


class VLLMCancellationHandler:
Expand Down Expand Up @@ -91,6 +147,33 @@ def generate(self, prompt, ...) -> Iterator[str]:

server = None
cancellation_handler = None
_metrics_poller = None

@property
def admission_control_backoff(self) -> float:
"""Seconds to wait before retrying after admission rejection. Override to customize."""
return 1.0

def check_admission(self) -> bool:
"""Fail-open: reject only when KV cache is saturated or waiting queue is too deep.

Called by the runner before dispatching a request. Returns True to admit, False to reject.
Admission control is disabled (always admits) when _metrics_poller is not set or is stale.
Enable by setting self._metrics_poller in load_model():

self._metrics_poller = VLLMMetricsPoller(f"http://{host}:{port}")
"""
if self._metrics_poller is None or self._metrics_poller.is_stale:
return True
p = self._metrics_poller
kv_cache, waiting = p.snapshot()
if kv_cache > p.KV_CACHE_REJECT_THRESHOLD:
logger.info(f"[AdmissionControl] REJECT kv_cache={kv_cache:.2%}")
return False
if waiting > p.MAX_WAITING_REQUESTS:
logger.info(f"[AdmissionControl] REJECT waiting={waiting}")
return False
return True

def handle_liveness_probe(self) -> bool:
if self.server is None:
Expand Down
58 changes: 57 additions & 1 deletion tests/runners/test_vllm_openai_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import pytest

from clarifai.runners.models.dummy_openai_model import MockOpenAIClient
from clarifai.runners.models.vllm_openai_class import VLLMCancellationHandler, VLLMOpenAIModelClass
from clarifai.runners.models.vllm_openai_class import (
VLLMCancellationHandler,
VLLMMetricsPoller,
VLLMOpenAIModelClass,
)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -228,3 +232,55 @@ def test_invalid_endpoint_raises_value_error(self):
with patch("clarifai.runners.models.vllm_openai_class.get_item_id", side_effect=Exception):
with pytest.raises(ValueError, match="Only"):
list(model.openai_stream_transport(request))


# ---------------------------------------------------------------------------
# VLLMOpenAIModelClass — admission control
# ---------------------------------------------------------------------------
class TestAdmissionControl:
    """check_admission() fail-open behavior and its rejection thresholds."""

    @staticmethod
    def _poller(*, stale=False, kv_cache=0.0, waiting=0):
        """Build a MagicMock VLLMMetricsPoller reporting the given state."""
        poller = MagicMock(spec=VLLMMetricsPoller)
        poller.is_stale = stale
        poller.KV_CACHE_REJECT_THRESHOLD = 0.95
        poller.MAX_WAITING_REQUESTS = 10
        poller.snapshot.return_value = (kv_cache, waiting)
        return poller

    def test_check_admission_no_poller_admits(self):
        """No metrics poller set → fail-open, always admit."""
        model = DummyVLLMModel()
        assert model._metrics_poller is None
        assert model.check_admission() is True

    def test_check_admission_stale_poller_admits(self):
        """Stale poller → fail-open."""
        model = DummyVLLMModel()
        model._metrics_poller = self._poller(stale=True)
        assert model.check_admission() is True

    def test_check_admission_kv_cache_over_threshold_rejects(self):
        model = DummyVLLMModel()
        # kv_cache above the 0.95 threshold
        model._metrics_poller = self._poller(kv_cache=0.96, waiting=0)
        assert model.check_admission() is False

    def test_check_admission_waiting_over_limit_rejects(self):
        model = DummyVLLMModel()
        # waiting above the limit of 10
        model._metrics_poller = self._poller(kv_cache=0.50, waiting=11)
        assert model.check_admission() is False

    def test_check_admission_healthy_admits(self):
        model = DummyVLLMModel()
        model._metrics_poller = self._poller(kv_cache=0.50, waiting=3)
        assert model.check_admission() is True

    def test_admission_control_backoff_default(self):
        assert DummyVLLMModel().admission_control_backoff == 1.0
Loading