From 89aa1b893e5862345b98643bcbc1be4582b9ceab Mon Sep 17 00:00:00 2001
From: Christine Yu
Date: Tue, 24 Mar 2026 15:48:12 -0400
Subject: [PATCH] add basic refactoring for vllm class admin control

---
 clarifai/runners/models/vllm_openai_class.py | 102 ++++++++++++++++++++
 tests/runners/test_vllm_openai_class.py      |  73 +++++++++++++-
 2 files changed, 174 insertions(+), 1 deletion(-)

diff --git a/clarifai/runners/models/vllm_openai_class.py b/clarifai/runners/models/vllm_openai_class.py
index b440e7c1..7f0432ab 100644
--- a/clarifai/runners/models/vllm_openai_class.py
+++ b/clarifai/runners/models/vllm_openai_class.py
@@ -1,10 +1,85 @@
+import re
 import threading
+import time
 from typing import Iterator
 
 import httpx
 from clarifai_protocol import get_item_id, register_item_abort_callback
 
 from clarifai.runners.models.openai_class import OpenAIModelClass
+from clarifai.utils.logging import logger
+
+
+class VLLMMetricsPoller:
+    """Polls vLLM /metrics in background; caches kv_cache_usage and waiting count.
+
+    A daemon thread starts on construction and polls every ``poll_interval`` seconds
+    until stop() is called. Start in load_model() to enable admission control:
+
+        self._metrics_poller = VLLMMetricsPoller(f"http://{host}:{port}")
+    """
+
+    KV_CACHE_REJECT_THRESHOLD = 0.95
+    MAX_WAITING_REQUESTS = 10
+    # Seconds without a successful poll after which cached values count as stale.
+    STALE_AFTER_SECONDS = 5.0
+
+    # Sample lines look like `name{labels} value`. Capture the whole value token:
+    # Prometheus may render small gauges in scientific notation (e.g. 1e-05), which
+    # a digits-and-dots capture would truncate to 1.0 and trip the KV-cache reject.
+    _WAITING_PATTERN = r'^vllm:num_requests_waiting(?:\{[^}]*\})?\s+(\S+)'
+    _KV_CACHE_PATTERN = r'^vllm:kv_cache_usage_perc(?:\{[^}]*\})?\s+(\S+)'
+
+    def __init__(self, base_url: str, poll_interval: float = 0.5):
+        self.base_url = base_url
+        self.poll_interval = poll_interval
+        self._kv_cache = 0.0
+        self._waiting = 0
+        self._lock = threading.Lock()
+        self._stop_event = threading.Event()
+        # Monotonic clock: staleness must not be skewed by wall-clock adjustments.
+        self._last_success = time.monotonic()
+        threading.Thread(target=self._poll_loop, daemon=True).start()
+
+    def stop(self):
+        """Ask the background polling thread to exit after its current cycle."""
+        self._stop_event.set()
+
+    def _poll_loop(self):
+        while not self._stop_event.is_set():
+            try:
+                resp = httpx.get(f"{self.base_url}/metrics", timeout=1.0)
+                if resp.status_code == 200:
+                    text = resp.text
+                    waiting = int(self._parse(text, self._WAITING_PATTERN))
+                    kv_cache = self._parse(text, self._KV_CACHE_PATTERN)
+                    with self._lock:
+                        self._waiting = waiting
+                        self._kv_cache = kv_cache
+                        self._last_success = time.monotonic()
+                else:
+                    logger.warning(f"[VLLMMetricsPoller] Poll failed: HTTP {resp.status_code}")
+            except Exception as e:
+                logger.warning(f"[VLLMMetricsPoller] Poll failed: {e}")
+            # Event.wait doubles as an interruptible sleep so stop() takes effect promptly.
+            self._stop_event.wait(self.poll_interval)
+
+    @staticmethod
+    def _parse(text: str, pattern: str) -> float:
+        """Return the first sample matching ``pattern`` in ``text``, or 0.0 if absent."""
+        m = re.search(pattern, text, re.MULTILINE)
+        return float(m.group(1)) if m else 0.0
+
+    def snapshot(self):
+        """Return (kv_cache, waiting) atomically."""
+        with self._lock:
+            return self._kv_cache, self._waiting
+
+    @property
+    def is_stale(self) -> bool:
+        """True when no poll has succeeded within STALE_AFTER_SECONDS."""
+        with self._lock:
+            return time.monotonic() - self._last_success > self.STALE_AFTER_SECONDS
 
 
 class VLLMCancellationHandler:
@@ -91,6 +166,33 @@ def generate(self, prompt, ...) -> Iterator[str]:
 
     server = None
     cancellation_handler = None
+    _metrics_poller = None
+
+    @property
+    def admission_control_backoff(self) -> float:
+        """Seconds to wait before retrying after admission rejection. Override to customize."""
+        return 1.0
+
+    def check_admission(self) -> bool:
+        """Fail-open: reject only when KV cache is saturated or waiting queue is too deep.
+
+        Called by the runner before dispatching a request. Returns True to admit, False to reject.
+        Admission control is disabled (always admits) when _metrics_poller is not set or is stale.
+        Enable by setting self._metrics_poller in load_model():
+
+            self._metrics_poller = VLLMMetricsPoller(f"http://{host}:{port}")
+        """
+        if self._metrics_poller is None or self._metrics_poller.is_stale:
+            return True
+        p = self._metrics_poller
+        kv_cache, waiting = p.snapshot()
+        if kv_cache > p.KV_CACHE_REJECT_THRESHOLD:
+            logger.info(f"[AdmissionControl] REJECT kv_cache={kv_cache:.2%}")
+            return False
+        if waiting > p.MAX_WAITING_REQUESTS:
+            logger.info(f"[AdmissionControl] REJECT waiting={waiting}")
+            return False
+        return True
 
     def handle_liveness_probe(self) -> bool:
         if self.server is None:
diff --git a/tests/runners/test_vllm_openai_class.py b/tests/runners/test_vllm_openai_class.py
index 58fdac45..c0c7d145 100644
--- a/tests/runners/test_vllm_openai_class.py
+++ b/tests/runners/test_vllm_openai_class.py
@@ -7,7 +7,11 @@
 import pytest
 
 from clarifai.runners.models.dummy_openai_model import MockOpenAIClient
-from clarifai.runners.models.vllm_openai_class import VLLMCancellationHandler, VLLMOpenAIModelClass
+from clarifai.runners.models.vllm_openai_class import (
+    VLLMCancellationHandler,
+    VLLMMetricsPoller,
+    VLLMOpenAIModelClass,
+)
 
 
 # ---------------------------------------------------------------------------
@@ -228,3 +232,70 @@ def test_invalid_endpoint_raises_value_error(self):
         with patch("clarifai.runners.models.vllm_openai_class.get_item_id", side_effect=Exception):
             with pytest.raises(ValueError, match="Only"):
                 list(model.openai_stream_transport(request))
+
+
+# ---------------------------------------------------------------------------
+# VLLMOpenAIModelClass — admission control
+# ---------------------------------------------------------------------------
+class TestAdmissionControl:
+    def test_check_admission_no_poller_admits(self):
+        """No metrics poller set → fail-open, always admit."""
+        model = DummyVLLMModel()
+        assert model._metrics_poller is None
+        assert model.check_admission() is True
+
+    def test_check_admission_stale_poller_admits(self):
+        """Stale poller → fail-open."""
+        model = DummyVLLMModel()
+        mock_poller = MagicMock(spec=VLLMMetricsPoller)
+        mock_poller.is_stale = True
+        model._metrics_poller = mock_poller
+        assert model.check_admission() is True
+
+    def test_check_admission_kv_cache_over_threshold_rejects(self):
+        model = DummyVLLMModel()
+        mock_poller = MagicMock(spec=VLLMMetricsPoller)
+        mock_poller.is_stale = False
+        mock_poller.KV_CACHE_REJECT_THRESHOLD = 0.95
+        mock_poller.MAX_WAITING_REQUESTS = 10
+        mock_poller.snapshot.return_value = (0.96, 0)  # kv_cache above threshold
+        model._metrics_poller = mock_poller
+        assert model.check_admission() is False
+
+    def test_check_admission_waiting_over_limit_rejects(self):
+        model = DummyVLLMModel()
+        mock_poller = MagicMock(spec=VLLMMetricsPoller)
+        mock_poller.is_stale = False
+        mock_poller.KV_CACHE_REJECT_THRESHOLD = 0.95
+        mock_poller.MAX_WAITING_REQUESTS = 10
+        mock_poller.snapshot.return_value = (0.50, 11)  # waiting above limit
+        model._metrics_poller = mock_poller
+        assert model.check_admission() is False
+
+    def test_check_admission_healthy_admits(self):
+        model = DummyVLLMModel()
+        mock_poller = MagicMock(spec=VLLMMetricsPoller)
+        mock_poller.is_stale = False
+        mock_poller.KV_CACHE_REJECT_THRESHOLD = 0.95
+        mock_poller.MAX_WAITING_REQUESTS = 10
+        mock_poller.snapshot.return_value = (0.50, 3)
+        model._metrics_poller = mock_poller
+        assert model.check_admission() is True
+
+    def test_admission_control_backoff_default(self):
+        assert DummyVLLMModel().admission_control_backoff == 1.0
+
+
+# ---------------------------------------------------------------------------
+# VLLMMetricsPoller — metrics parsing
+# ---------------------------------------------------------------------------
+class TestMetricsPollerParsing:
+    def test_parse_scientific_notation(self):
+        """Tiny gauge values rendered as 1e-05 must not be read as 1.0."""
+        text = 'vllm:kv_cache_usage_perc{model_name="m"} 1e-05\n'
+        value = VLLMMetricsPoller._parse(text, VLLMMetricsPoller._KV_CACHE_PATTERN)
+        assert value == pytest.approx(1e-05)
+
+    def test_parse_missing_metric_defaults_to_zero(self):
+        text = '# HELP vllm:num_requests_waiting Number of requests waiting.\n'
+        assert VLLMMetricsPoller._parse(text, VLLMMetricsPoller._WAITING_PATTERN) == 0.0