From 735855e9cde29ec28aadaf12a92ffdadf09e3d54 Mon Sep 17 00:00:00 2001 From: Alexander Osipenko <11722602+subpath@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:25:42 +0100 Subject: [PATCH 1/7] add endpoint to get user count --- .../core/pg_services/litellm_pg_service.py | 38 +++++++++++++++++++ src/mlpa/core/routers/user/user.py | 8 ++++ src/tests/integration/test_user_management.py | 38 +++++++++++++++++++ src/tests/mocks.py | 17 +++++++++ 4 files changed, 101 insertions(+) diff --git a/src/mlpa/core/pg_services/litellm_pg_service.py b/src/mlpa/core/pg_services/litellm_pg_service.py index bdf6812..30c9ae7 100644 --- a/src/mlpa/core/pg_services/litellm_pg_service.py +++ b/src/mlpa/core/pg_services/litellm_pg_service.py @@ -98,6 +98,44 @@ async def list_users(self, limit: int = 50, offset: int = 0) -> dict: status_code=500, detail={"error": "Error listing users"} ) + async def count_users_by_service_type(self) -> dict: + """ + Return total user counts grouped by service_type. + + LiteLLM stores users by `user_id` which is formatted as: + `{base_user_id}:{service_type}`. + """ + try: + rows = await self.pg.fetch( + """ + SELECT + split_part(user_id, ':', 2) AS service_type, + COUNT(*)::int AS total_users + FROM "LiteLLM_EndUserTable" + WHERE position(':' in user_id) > 0 + GROUP BY service_type + """ + ) + + service_type_counts: dict[str, int] = {} + for row in rows: + service_type = row.get("service_type") + if not service_type: + continue + service_type_counts[str(service_type)] = int(row.get("total_users", 0)) + + total_users = sum(service_type_counts.values()) + return { + "service_type_counts": service_type_counts, + "total_users": int(total_users), + } + except Exception as e: + logger.error(f"Error counting users by service type: {e}") + raise HTTPException( + status_code=500, + detail={"error": "Error counting users by service type"}, + ) + async def create_budget(self): """ Create end user budgets from configuration. diff --git a/src/mlpa/core/routers/user/user.py b/src/mlpa/core/routers/user/user.py index 73ee5e4..04a298d 100644 --- a/src/mlpa/core/routers/user/user.py +++ b/src/mlpa/core/routers/user/user.py @@ -37,6 +37,14 @@ async def list_users( return await litellm_pg.list_users(limit=limit, offset=offset) +@router.get("/counts-by-service-type", tags=["User Management"]) +async def count_users_by_service_type( + _: Annotated[None, Depends(require_master_key)] = None, +): + """Get total user counts grouped by service_type.""" + return await litellm_pg.count_users_by_service_type() + + @router.get("/{user_id}", tags=["User"]) async def user_info(user_id: str): if not user_id or user_id.strip() == "": diff --git a/src/tests/integration/test_user_management.py b/src/tests/integration/test_user_management.py index 7f21d87..cd45e85 100644 --- a/src/tests/integration/test_user_management.py +++ b/src/tests/integration/test_user_management.py @@ -268,3 +268,41 @@ def test_list_users_empty_result(mocked_client_integration, mocker): data = response.json() assert data["total"] == 0 assert len(data["users"]) == 0 + + +def test_count_users_by_service_type_success(mocked_client_integration, mocker): + from tests.mocks import MockLiteLLMPGService + + mock_litellm_pg = MockLiteLLMPGService() + mock_litellm_pg.store_user( + "user1:ai", {"user_id": "user1:ai", "blocked": False, "alias": None} + ) + mock_litellm_pg.store_user( + "user2:ai", {"user_id": "user2:ai", "blocked": False, "alias": None} + ) + mock_litellm_pg.store_user( + "user3:s2s", + {"user_id": "user3:s2s", "blocked": False, "alias": None}, + ) + + mocker.patch("mlpa.core.routers.user.user.litellm_pg", mock_litellm_pg) + + response = mocked_client_integration.get( + "/user/counts-by-service-type", + headers={"master_key": f"Bearer {env.MASTER_KEY}"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["total_users"] == 3 + assert data["service_type_counts"]["ai"] == 2 + assert data["service_type_counts"]["s2s"] == 1 + + +def test_count_users_by_service_type_unauthorized(mocked_client_integration): + response = mocked_client_integration.get( + "/user/counts-by-service-type", + headers={"master_key": "Bearer invalid-key"}, + ) + + assert response.status_code == 401 + assert "Unauthorized" in str(response.json()) diff --git a/src/tests/mocks.py b/src/tests/mocks.py index 25529f5..4a3b4cc 100644 --- a/src/tests/mocks.py +++ b/src/tests/mocks.py @@ -205,6 +205,23 @@ async def list_users(self, limit: int = 50, offset: int = 0): "offset": offset, } + async def count_users_by_service_type(self) -> dict: + """Mock count_users_by_service_type grouped by service_type.""" + service_type_counts: dict[str, int] = {} + for user_id in self.users.keys(): + service_type = user_id.split(":")[1] if ":" in user_id else "" + if not service_type: + continue + service_type_counts[service_type] = ( + service_type_counts.get(service_type, 0) + 1 + ) + + total_users = sum(service_type_counts.values()) + return { + "service_type_counts": service_type_counts, + "total_users": total_users, + } + class MockFxAService: def __init__(self, client_id: str, client_secret: str, fxa_url: str): From fec21afe431d8d9d972e192534c71394c4cd5535 Mon Sep 17 00:00:00 2001 From: Alexander Osipenko <11722602+subpath@users.noreply.github.com> Date: Thu, 19 Mar 2026 17:26:15 +0100 Subject: [PATCH 2/7] update docs --- docs/index.html | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/docs/index.html b/docs/index.html index 4f19ce7..54b7b34 100644 --- a/docs/index.html +++ b/docs/index.html @@ -373,7 +373,7 @@
-nullUpdate a user's budget tier by service type (e.g. ai-dev for higher limits).
| user_id required | string (User Id) |
| master-key required | string (Master Key) |
| service_type required | string (Service Type) |
{- "service_type": "string"
}null