From 6abcae8de9933a06f68cd15019e3fc8fb43f5715 Mon Sep 17 00:00:00 2001
From: ishaanxgupta <124028055+ishaanxgupta@users.noreply.github.com>
Date: Wed, 25 Feb 2026 18:25:15 +0000
Subject: [PATCH] feat: Add rate limiting middleware to server.py

---
Reviewer notes (this area is ignored when the patch is applied):
* Line structure was restored from a copy whose newlines had been collapsed.
  Blank context lines inside the hunks are reconstructed; verify them
  against server.py (or apply with --recount / --ignore-whitespace).
* The middleware now evicts idle clients, uses a monotonic clock and a
  single WINDOW_SECONDS constant; the post-image blob id on the index line
  predates these changes.

 server.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 65 insertions(+), 1 deletion(-)

diff --git a/server.py b/server.py
index 335106a..2958ce5 100644
--- a/server.py
+++ b/server.py
@@ -18,11 +18,12 @@
 from pathlib import Path
 from typing import Any, Dict, List
 
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import FileResponse, JSONResponse
 from fastapi.staticfiles import StaticFiles
 from pydantic import BaseModel
+from starlette.middleware.base import BaseHTTPMiddleware
 
 # ── Project root setup ────────────────────────────────────────────
 import sys
@@ -46,12 +47,73 @@
 
 from src.pipelines.ingest import IngestPipeline
 from src.pipelines.retrieval import RetrievalPipeline
+from src.config import settings
 
 
 # ═══════════════════════════════════════════════════════════════════
 # Log capture — collects pipeline log messages during a run
 # ═══════════════════════════════════════════════════════════════════
 
+
+# ═══════════════════════════════════════════════════════════════════
+# Rate Limiting Middleware
+# ═══════════════════════════════════════════════════════════════════
+
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Limit each client IP to ``settings.rate_limit`` requests per minute.
+
+    Accepted request times are kept in process memory as a sliding window
+    per client, so counts are per-worker and are lost on restart.
+    """
+
+    WINDOW_SECONDS = 60
+
+    def __init__(self, app):
+        super().__init__(app)
+        # client IP -> times of the requests accepted inside the window
+        self.rate_limit_records: Dict[str, List[float]] = {}
+        self._last_sweep = time.monotonic()
+
+    def _sweep(self, now: float) -> None:
+        """Forget clients whose newest request has left the window."""
+        self.rate_limit_records = {
+            ip: stamps
+            for ip, stamps in self.rate_limit_records.items()
+            if stamps and now - stamps[-1] < self.WINDOW_SECONDS
+        }
+        self._last_sweep = now
+
+    async def dispatch(self, request: Request, call_next):
+        """Answer 429 once the client is over its limit, else pass through."""
+        # request.client can be None; such callers share a single bucket.
+        client_ip = request.client.host if request.client else "unknown"
+        # Monotonic clock: wall-clock adjustments must not distort the window.
+        now = time.monotonic()
+
+        # Without a periodic sweep the table keeps one entry per IP forever.
+        if now - self._last_sweep >= self.WINDOW_SECONDS:
+            self._sweep(now)
+
+        recent = [
+            t for t in self.rate_limit_records.get(client_ip, [])
+            if now - t < self.WINDOW_SECONDS
+        ]
+        self.rate_limit_records[client_ip] = recent
+
+        if len(recent) >= settings.rate_limit:
+            return JSONResponse(
+                status_code=429,
+                content={
+                    "error": "Too many requests",
+                    "detail": f"Rate limit exceeded: {settings.rate_limit} per minute",
+                },
+                headers={"Retry-After": str(self.WINDOW_SECONDS)},
+            )
+
+        # Only accepted requests count against the limit.
+        recent.append(now)
+        return await call_next(request)
+
 class StepCapture(logging.Handler):
     """Captures log records into a list of step dicts."""
 
@@ -123,6 +185,8 @@ async def lifespan(app: FastAPI):
 
 app = FastAPI(title="Xmem Test Frontend", lifespan=lifespan)
 
+app.add_middleware(RateLimitMiddleware)
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],