diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index 840f3c8..d946aae 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -14,7 +14,8 @@ jobs: build-and-push: runs-on: ubuntu-latest env: - IMAGE_NAME: ghcr.io/${{ github.repository }} + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} steps: - name: Checkout uses: actions/checkout@v4 @@ -22,7 +23,7 @@ jobs: - name: Set image name shell: bash run: | - echo "IMAGE_NAME=ghcr.io/${GITHUB_REPOSITORY,,}" >> "$GITHUB_ENV" + echo "IMAGE_NAME=${GITHUB_REPOSITORY,,}" >> "$GITHUB_ENV" - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -33,9 +34,9 @@ jobs: - name: Login to GHCR uses: docker/login-action@v3 with: - registry: ghcr.io + registry: ${{ env.REGISTRY }} username: ${{ github.actor }} - password: ${{ secrets.GHCR_PAT }} + password: ${{ secrets.GITHUB_TOKEN }} - name: Build and push uses: docker/build-push-action@v5 diff --git a/app/core/config.py b/app/core/config.py index 3267862..47a6725 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -31,6 +31,13 @@ class Settings(BaseSettings): POSTGRES_HOST: str = "localhost" POSTGRES_PORT: int = 5432 + # Mobile auth/session defaults + MOBILE_SESSION_LIMIT: int = 3 + MOBILE_SESSION_TTL_SECONDS: int = 180 + MOBILE_SESSION_DAYS: int = 7 + # Admin list defaults + ADMIN_USERS_DEFAULT_LIMIT: int = 20 + ADMIN_USERS_MAX_LIMIT: int = 100 # Security jwt_secret: str jwt_algorithm: str = "HS256" diff --git a/app/deps/token_auth.py b/app/deps/token_auth.py index c5fe522..a7eff17 100644 --- a/app/deps/token_auth.py +++ b/app/deps/token_auth.py @@ -43,6 +43,8 @@ async def get_current_mobile_user( user = await container.auth_service.user_querier.get_user_by_id(id=session.user_id) if not user: raise HTTPException(status_code=401, detail="User not found") + if user.blocked: + raise HTTPException(status_code=403, detail="User is blocked") return MobileUserSchema( user_id=user.id, diff --git a/app/router/mobile/auth.py b/app/router/mobile/auth.py index 1784300..72c33db 100644 --- a/app/router/mobile/auth.py +++ b/app/router/mobile/auth.py @@ -4,7 +4,6 @@ from uuid import UUID from app.container import get_container, Container -from app.core.exceptions import AppException from app.deps.token_auth import MobileUserSchema, get_current_mobile_user from app.schema.request.mobile.auth import MobileAuthRequest, RefreshTokenRequest @@ -18,7 +17,6 @@ async def mobile_register_login( req: MobileAuthRequest, container: Container = Depends(get_container), ) -> MobileAuthResponse: - return await container.auth_service.mobile_register_login(container.redis, req) @@ -27,20 +25,18 @@ async def refresh_token( req: RefreshTokenRequest, container: Container = Depends(get_container), ) -> MobileAuthResponse: - return await container.auth_service.refresh_token(container.redis, req.refresh_token) @router.post("/logout") async def logout( container: Container = Depends(get_container), - User: MobileUserSchema = Depends(get_current_mobile_user), + current_user: MobileUserSchema = Depends(get_current_mobile_user), ) -> dict[str, str]: - return await container.auth_service.logout( container.redis, - str(User.user_id), - str(User.session_id), + str(current_user.user_id), + str(current_user.session_id), ) @@ -50,7 +46,6 @@ async def revoke_device( container: Container = Depends(get_container), current_user: MobileUserSchema = Depends(get_current_mobile_user), ) -> dict[str, str]: - await container.device_service.revoke_device( device_id=device_id, user_id=current_user.user_id, @@ -63,17 +58,14 @@ async def get_me( current_user: MobileUserSchema = Depends(get_current_mobile_user), container: Container = Depends(get_container), ) -> MeResponse: - - user = await container.auth_service.user_querier.get_user_by_id(id=current_user.user_id) - if user is None : - raise AppException.not_found("user not found") + user = await container.auth_service.get_user(user_id=current_user.user_id) devices, _ = await container.device_service.get_all_devices(current_user.user_id) device_list = [ DeviceSchema( id=d.id, - device_name=d.device_name or "uknown ", - device_type=d.device_type or "uknown ", + device_name=d.device_name or "unknown", + device_type=d.device_type or "unknown", totp_secret=d.totp_secret, ) for d in devices @@ -92,8 +84,6 @@ async def get_me( expires_at=sessions_objs.expires_at, ) - - return MeResponse( user=UserSchema(id=user.id, email=user.email), devices=device_list, diff --git a/app/router/web/__init__.py b/app/router/web/__init__.py index 396add6..c6f639e 100644 --- a/app/router/web/__init__.py +++ b/app/router/web/__init__.py @@ -2,8 +2,10 @@ from app.router.web.staff_users import router as staff_users_router from app.router.web.event import router as event_router from app.router.web.auth import router as auth_routes +from app.router.web.users import router as users_router router = APIRouter(prefix="/admin", tags=["admin"]) router.include_router(staff_users_router) router.include_router(event_router) router.include_router(auth_routes) +router.include_router(users_router) diff --git a/app/router/web/users.py b/app/router/web/users.py new file mode 100644 index 0000000..4866d5f --- /dev/null +++ b/app/router/web/users.py @@ -0,0 +1,108 @@ +from uuid import UUID + +from fastapi import APIRouter, Depends, Query, status + +from app.container import Container, get_container +from app.core.config import settings +from app.core.logger import logger +from app.deps.cookie_auth import get_current_staff_user +from app.schema.request.web.user import AdminUserCreateRequest, AdminUserUpdateRequest +from app.schema.response.web.user import AdminUserSchema, to_admin_user_schema +from db.generated.models import StaffUser + + +router = APIRouter(prefix="/users") + +@router.post("/", response_model=AdminUserSchema, status_code=status.HTTP_201_CREATED) +async def create_user( + req: AdminUserCreateRequest, + current_staff_user: StaffUser = Depends(get_current_staff_user), + container: Container = Depends(get_container), +) -> AdminUserSchema: + user = await container.auth_service.create_user( + email=req.email, + password=req.password, + display_name=req.display_name, + blocked=req.blocked, + ) + logger.info("admin %s created user %s", current_staff_user.id, user.id) + return to_admin_user_schema(user) + + +@router.get("/", response_model=list[AdminUserSchema]) +async def list_users( + limit: int = Query( + settings.ADMIN_USERS_DEFAULT_LIMIT, ge=1, le=settings.ADMIN_USERS_MAX_LIMIT + ), + offset: int = Query(0, ge=0), + current_staff_user: StaffUser = Depends(get_current_staff_user), + container: Container = Depends(get_container), +) -> list[AdminUserSchema]: + users = await container.auth_service.list_users(limit=limit, offset=offset) + return [to_admin_user_schema(user) for user in users] + + +@router.get("/{user_id}", response_model=AdminUserSchema) +async def get_user( + user_id: UUID, + current_staff_user: StaffUser = Depends(get_current_staff_user), + container: Container = Depends(get_container), +) -> AdminUserSchema: + user = await container.auth_service.get_user(user_id=user_id) + return to_admin_user_schema(user) + + +@router.put("/{user_id}", response_model=AdminUserSchema) +async def update_user( + user_id: UUID, + req: AdminUserUpdateRequest, + current_staff_user: StaffUser = Depends(get_current_staff_user), + container: Container = Depends(get_container), +) -> AdminUserSchema: + user = await container.auth_service.update_user( + user_id=user_id, + email=req.email, + display_name=req.display_name, + blocked=req.blocked, + ) + logger.info("admin %s updated user %s", current_staff_user.id, user_id) + return to_admin_user_schema(user) + + +@router.delete("/{user_id}", response_model=AdminUserSchema) +async def delete_user( + user_id: UUID, + current_staff_user: StaffUser = Depends(get_current_staff_user), + container: Container = Depends(get_container), +) -> AdminUserSchema: + user = await container.auth_service.delete_user( + redis=container.redis, + user_id=user_id, + ) + logger.info("admin %s deleted user %s", current_staff_user.id, user_id) + return to_admin_user_schema(user) + + +@router.post("/{user_id}/block", response_model=AdminUserSchema) +async def block_user( + user_id: UUID, + current_staff_user: StaffUser = Depends(get_current_staff_user), + container: Container = Depends(get_container), +) -> AdminUserSchema: + user = await container.auth_service.block_user( + redis=container.redis, + user_id=user_id, + ) + logger.info("admin %s blocked user %s", current_staff_user.id, user_id) + return to_admin_user_schema(user) + + +@router.post("/{user_id}/unblock", response_model=AdminUserSchema) +async def unblock_user( + user_id: UUID, + current_staff_user: StaffUser = Depends(get_current_staff_user), + container: Container = Depends(get_container), +) -> AdminUserSchema: + user = await container.auth_service.unblock_user(user_id=user_id) + logger.info("admin %s unblocked user %s", current_staff_user.id, user_id) + return to_admin_user_schema(user) diff --git a/app/schema/request/web/user.py b/app/schema/request/web/user.py new file mode 100644 index 0000000..2b41695 --- /dev/null +++ b/app/schema/request/web/user.py @@ -0,0 +1,15 @@ +from typing import Optional +from pydantic import BaseModel, EmailStr, Field + + +class AdminUserCreateRequest(BaseModel): + email: EmailStr + password: str = Field(..., min_length=8) + display_name: Optional[str] = None + blocked: bool = False + + +class AdminUserUpdateRequest(BaseModel): + email: Optional[EmailStr] = None + display_name: Optional[str] = None + blocked: Optional[bool] = None diff --git a/app/schema/response/web/user.py b/app/schema/response/web/user.py new file mode 100644 index 0000000..bd79627 --- /dev/null +++ b/app/schema/response/web/user.py @@ -0,0 +1,25 @@ +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel +from db.generated.models import User + + +class AdminUserSchema(BaseModel): + id: UUID + email: str + display_name: str | None + blocked: bool + created_at: datetime + updated_at: datetime + + +def to_admin_user_schema(user: User) -> AdminUserSchema: + return AdminUserSchema( + id=user.id, + email=user.email, + display_name=user.display_name, + blocked=user.blocked, + created_at=user.created_at, + updated_at=user.updated_at, + ) diff --git a/app/service/face_embedding.py b/app/service/face_embedding.py index f71c906..7d4f6d4 100644 --- a/app/service/face_embedding.py +++ b/app/service/face_embedding.py @@ -3,9 +3,9 @@ import asyncio from typing import List, Literal, Optional, Sequence, Tuple, TypedDict -import cv2 +import cv2 # type: ignore import numpy as np -from insightface.app import FaceAnalysis # type: ignore +from insightface.app import FaceAnalysis # type: ignore[import-untyped] from app.core.exceptions import AppException diff --git a/app/service/staff_user.py b/app/service/staff_user.py index 1da37f5..6241818 100644 --- a/app/service/staff_user.py +++ b/app/service/staff_user.py @@ -110,8 +110,8 @@ async def admin_login( ) -> WebAuthResponse: print("hello") staff: StaffUser | None = await self.staff_user_querier.get_staff_user_by_email(email=email) - if staff is None or not verify_password(password, staff.password): - logger.info(f'user:{staff.email}') # type: ignore + if staff is None or not verify_password(password, staff.password): + logger.info("admin login failed for email %s", email) raise AppException.unauthorized("Invalid email or password") diff --git a/app/service/users.py b/app/service/users.py index ecfaf91..e4801d4 100644 --- a/app/service/users.py +++ b/app/service/users.py @@ -2,7 +2,7 @@ import uuid from app.core import constant -from app.core.exceptions import AppException +from app.core.exceptions import AppException, DBException from app.core.securite import ( # EmbeddingCrypto, hash_password, @@ -12,6 +12,7 @@ decode_refresh_mobile_token, Get_expiry_time, ) +from app.core.config import settings from app.infra.redis import RedisClient from app.schema.request.mobile.auth import MobileAuthRequest @@ -28,8 +29,8 @@ class AuthService: user_querier: user_queries.AsyncQuerier device_querier: device_queries.AsyncQuerier session_querier: session_queries.AsyncQuerier - SESSION_LIMIT = 3 - REDIS_SESSION_TTL = 180 + SESSION_LIMIT = settings.MOBILE_SESSION_LIMIT + REDIS_SESSION_TTL = settings.MOBILE_SESSION_TTL_SECONDS def __init__( self, @@ -53,6 +54,8 @@ async def mobile_register_login( user: User | None = None if existing_user is not None: + if existing_user.blocked: + raise AppException.forbidden("User is blocked") if not verify_password(req.password, existing_user.hashed_password or ""): raise AppException.unauthorized("Invalid credentials") user = existing_user @@ -71,8 +74,6 @@ async def mobile_register_login( user_id: uuid.UUID = user.id session_key = constant.RedisKey.UserSessionByUser.value.format(user_id=user_id) - if await redis.exists(session_key): - raise AppException.forbidden("User already has an active session") session_count = await self.session_querier.count_user_sessions(user_id=user_id) if session_count and session_count >= AuthService.SESSION_LIMIT: @@ -84,7 +85,9 @@ async def mobile_register_login( raise AppException.forbidden("Maximum session limit reached") device_id = req.device_id - expires_at = datetime.now(timezone.utc) + timedelta(days=7) + expires_at = datetime.now(timezone.utc) + timedelta( + days=settings.MOBILE_SESSION_DAYS + ) device = await self.device_querier.create_device( arg=device_queries.CreateDeviceParams( @@ -145,15 +148,11 @@ async def refresh_token( if session.expires_at < datetime.now(timezone.utc): raise AppException.unauthorized("Session expired") - session_key = constant.RedisKey.UserSessionByUser.value.format( - user_id=session.user_id - ) - redis_session = await redis.get(session_key) - - if not redis_session or redis_session != session_id: - raise AppException.unauthorized("Session invalidated") - - await redis.expire(session_key, AuthService.REDIS_SESSION_TTL) + user = await self.user_querier.get_user_by_id(id=session.user_id) + if not user: + raise AppException.unauthorized("User not found") + if user.blocked: + raise AppException.forbidden("User is blocked") new_access_token = create_acces_mobile_token(session_id) new_refresh_token = create_refresh_mobile_token(session_id) @@ -211,13 +210,129 @@ async def validate_session( if session.expires_at < datetime.now(timezone.utc): return False - - session_key = constant.RedisKey.UserSessionByUser.value.format( - user_id=session.user_id - ) - redis_session = await redis.get(session_key) - - return redis_session == session_id + return True async def get_user_by_id(self, user_id: uuid.UUID) -> User | None: return await self.user_querier.get_user_by_id(id=user_id) + + async def create_user( + self, + *, + email: str, + password: str, + display_name: str | None = None, + blocked: bool = False, + ) -> User: + try: + hashed = hash_password(password) + user = await self.user_querier.create_user( + email=email, + hashed_password=hashed, + ) + if not user: + raise AppException.internal_error("Failed to create user") + + if display_name is not None or blocked: + updated = await self.user_querier.update_user( + email=user.email, + display_name=display_name, + blocked=blocked, + id=user.id, + ) + if not updated: + raise AppException.internal_error("Failed to update user") + return updated + + return user + except Exception as exc: + logger.error("Failed to create user: %s", exc) + raise DBException.handle(exc) + + async def get_user(self, *, user_id: uuid.UUID) -> User: + user = await self.user_querier.get_user_by_id(id=user_id) + if not user: + raise AppException.not_found("User not found") + return user + + async def list_users(self, *, limit: int, offset: int) -> list[User]: + try: + users: list[User] = [] + async for user in self.user_querier.list_users(limit=limit, offset=offset): + users.append(user) + return users + except Exception as exc: + logger.error("Failed to list users: %s", exc) + raise DBException.handle(exc) + + async def update_user( + self, + *, + user_id: uuid.UUID, + email: str | None = None, + display_name: str | None = None, + blocked: bool | None = None, + ) -> User: + try: + existing = await self.user_querier.get_user_by_id(id=user_id) + if not existing: + raise AppException.not_found("User not found") + + new_email = email if email is not None else existing.email + new_display_name = ( + display_name if display_name is not None else existing.display_name + ) + new_blocked = blocked if blocked is not None else existing.blocked + + user = await self.user_querier.update_user( + email=new_email, + display_name=new_display_name, + blocked=new_blocked, + id=user_id, + ) + if not user: + raise AppException.internal_error("Failed to update user") + return user + except Exception as exc: + logger.error("Failed to update user: %s", exc) + raise DBException.handle(exc) + + async def delete_user(self, *, redis: RedisClient, user_id: uuid.UUID) -> User: + try: + existing = await self.user_querier.get_user_by_id(id=user_id) + if not existing: + raise AppException.not_found("User not found") + await self.user_querier.delete_user(id=user_id) + session_key = constant.RedisKey.UserSessionByUser.value.format( + user_id=user_id + ) + await redis.delete(session_key) + return existing + except Exception as exc: + logger.error("Failed to delete user: %s", exc) + raise DBException.handle(exc) + + async def block_user(self, *, redis: RedisClient, user_id: uuid.UUID) -> User: + try: + user = await self.user_querier.set_user_blocked(blocked=True, id=user_id) + if not user: + raise AppException.not_found("User not found") + + session_key = constant.RedisKey.UserSessionByUser.value.format( + user_id=user_id + ) + await redis.delete(session_key) + + return user + except Exception as exc: + logger.error("Failed to block user: %s", exc) + raise DBException.handle(exc) + + async def unblock_user(self, *, user_id: uuid.UUID) -> User: + try: + user = await self.user_querier.set_user_blocked(blocked=False, id=user_id) + if not user: + raise AppException.not_found("User not found") + return user + except Exception as exc: + logger.error("Failed to unblock user: %s", exc) + raise DBException.handle(exc) diff --git a/db/generated/models.py b/db/generated/models.py index 28a9da1..482ab59 100644 --- a/db/generated/models.py +++ b/db/generated/models.py @@ -203,6 +203,7 @@ class User: updated_at: datetime.datetime display_name: Optional[str] face_embedding: Optional[Any] + blocked: bool deleted_at: Optional[datetime.datetime] diff --git a/db/generated/session.py b/db/generated/session.py index 1b8e026..bc7b427 100644 --- a/db/generated/session.py +++ b/db/generated/session.py @@ -4,7 +4,7 @@ # source: session.sql import dataclasses import datetime -from typing import Optional +from typing import AsyncIterator, Optional import uuid import sqlalchemy @@ -51,6 +51,13 @@ """ +LIST_SESSIONS_BY_USER = """-- name: list_sessions_by_user \\:many +SELECT id, user_id, device_id, created_at, last_active, expires_at +FROM user_sessions +WHERE user_id = :p1 +""" + + UPDATE_SESSION_ACTIVITY = """-- name: update_session_activity \\:exec UPDATE user_sessions SET last_active = NOW() @@ -135,6 +142,18 @@ async def get_session_by_id(self, *, id: uuid.UUID) -> Optional[models.UserSessi expires_at=row[5], ) + async def list_sessions_by_user(self, *, user_id: uuid.UUID) -> AsyncIterator[models.UserSession]: + result = await self._conn.stream(sqlalchemy.text(LIST_SESSIONS_BY_USER), {"p1": user_id}) + async for row in result: + yield models.UserSession( + id=row[0], + user_id=row[1], + device_id=row[2], + created_at=row[3], + last_active=row[4], + expires_at=row[5], + ) + async def update_session_activity(self, *, id: uuid.UUID) -> None: await self._conn.execute(sqlalchemy.text(UPDATE_SESSION_ACTIVITY), {"p1": id}) diff --git a/db/generated/user.py b/db/generated/user.py index 2599d3a..823be6a 100644 --- a/db/generated/user.py +++ b/db/generated/user.py @@ -14,7 +14,7 @@ CREATE_USER = """-- name: create_user \\:one INSERT INTO users (email, hashed_password) VALUES (:p1, :p2) -RETURNING id, email, hashed_password, created_at, updated_at, display_name, face_embedding, deleted_at +RETURNING id, email, hashed_password, created_at, updated_at, display_name, face_embedding, blocked, deleted_at """ @@ -25,33 +25,53 @@ GET_USER_BY_EMAIL = """-- name: get_user_by_email \\:one -SELECT id, email, hashed_password, created_at, updated_at, display_name, face_embedding, deleted_at +SELECT id, email, hashed_password, created_at, updated_at, display_name, face_embedding, blocked, deleted_at FROM users WHERE email = :p1 """ GET_USER_BY_ID = """-- name: get_user_by_id \\:one -SELECT id, email, hashed_password, created_at, updated_at, display_name, face_embedding, deleted_at +SELECT id, email, hashed_password, created_at, updated_at, display_name, face_embedding, blocked, deleted_at FROM users WHERE id = :p1 """ LIST_USERS = """-- name: list_users \\:many -SELECT id, email, hashed_password, created_at, updated_at, display_name, face_embedding, deleted_at +SELECT id, email, hashed_password, created_at, updated_at, display_name, face_embedding, blocked, deleted_at FROM users ORDER BY created_at DESC LIMIT :p1 OFFSET :p2 """ +SET_USER_BLOCKED = """-- name: set_user_blocked \\:one +UPDATE users +SET blocked = :p1, + updated_at = NOW() +WHERE id = :p2 +RETURNING id, email, hashed_password, created_at, updated_at, display_name, face_embedding, blocked, deleted_at +""" + + SET_USER_EMBEDDING = """-- name: set_user_embedding \\:one UPDATE users SET face_embedding = :p1\\:\\:vector, updated_at = NOW() WHERE id = :p2 -RETURNING id, email, hashed_password, created_at, updated_at, display_name, face_embedding, deleted_at +RETURNING id, email, hashed_password, created_at, updated_at, display_name, face_embedding, blocked, deleted_at +""" + + +UPDATE_USER = """-- name: update_user \\:one +UPDATE users +SET email = COALESCE(:p1, email), + display_name = COALESCE(:p2, display_name), + blocked = COALESCE(:p3, blocked), + updated_at = NOW() +WHERE id = :p4 +RETURNING id, email, hashed_password, created_at, updated_at, display_name, face_embedding, blocked, deleted_at """ @@ -60,7 +80,7 @@ SET hashed_password = :p1, updated_at = NOW() WHERE id = :p2 -RETURNING id, email, hashed_password, created_at, updated_at, display_name, face_embedding, deleted_at +RETURNING id, email, hashed_password, created_at, updated_at, display_name, face_embedding, blocked, deleted_at """ @@ -80,7 +100,8 @@ async def create_user(self, *, email: str, hashed_password: Optional[str]) -> Op updated_at=row[4], display_name=row[5], face_embedding=row[6], - deleted_at=row[7], + blocked=row[7], + deleted_at=row[8], ) async def delete_user(self, *, id: uuid.UUID) -> None: @@ -98,7 +119,8 @@ async def get_user_by_email(self, *, email: str) -> Optional[models.User]: updated_at=row[4], display_name=row[5], face_embedding=row[6], - deleted_at=row[7], + blocked=row[7], + deleted_at=row[8], ) async def get_user_by_id(self, *, id: uuid.UUID) -> Optional[models.User]: @@ -113,7 +135,8 @@ async def get_user_by_id(self, *, id: uuid.UUID) -> Optional[models.User]: updated_at=row[4], display_name=row[5], face_embedding=row[6], - deleted_at=row[7], + blocked=row[7], + deleted_at=row[8], ) async def list_users(self, *, limit: int, offset: int) -> AsyncIterator[models.User]: @@ -127,9 +150,26 @@ async def list_users(self, *, limit: int, offset: int) -> AsyncIterator[models.U updated_at=row[4], display_name=row[5], face_embedding=row[6], - deleted_at=row[7], + blocked=row[7], + deleted_at=row[8], ) + async def set_user_blocked(self, *, blocked: bool, id: uuid.UUID) -> Optional[models.User]: + row = (await self._conn.execute(sqlalchemy.text(SET_USER_BLOCKED), {"p1": blocked, "p2": id})).first() + if row is None: + return None + return models.User( + id=row[0], + email=row[1], + hashed_password=row[2], + created_at=row[3], + updated_at=row[4], + display_name=row[5], + face_embedding=row[6], + blocked=row[7], + deleted_at=row[8], + ) + async def set_user_embedding(self, *, dollar_1: Any, id: uuid.UUID) -> Optional[models.User]: row = (await self._conn.execute(sqlalchemy.text(SET_USER_EMBEDDING), {"p1": dollar_1, "p2": id})).first() if row is None: @@ -142,7 +182,29 @@ async def set_user_embedding(self, *, dollar_1: Any, id: uuid.UUID) -> Optional[ updated_at=row[4], display_name=row[5], face_embedding=row[6], - deleted_at=row[7], + blocked=row[7], + deleted_at=row[8], + ) + + async def update_user(self, *, email: str, display_name: Optional[str], blocked: bool, id: uuid.UUID) -> Optional[models.User]: + row = (await self._conn.execute(sqlalchemy.text(UPDATE_USER), { + "p1": email, + "p2": display_name, + "p3": blocked, + "p4": id, + })).first() + if row is None: + return None + return models.User( + id=row[0], + email=row[1], + hashed_password=row[2], + created_at=row[3], + updated_at=row[4], + display_name=row[5], + face_embedding=row[6], + blocked=row[7], + deleted_at=row[8], ) async def update_user_password(self, *, hashed_password: Optional[str], id: uuid.UUID) -> Optional[models.User]: @@ -157,5 +219,6 @@ async def update_user_password(self, *, hashed_password: Optional[str], id: uuid updated_at=row[4], display_name=row[5], face_embedding=row[6], - deleted_at=row[7], + blocked=row[7], + deleted_at=row[8], ) diff --git a/db/queries/session.sql b/db/queries/session.sql index 2a5b859..b22911e 100644 --- a/db/queries/session.sql +++ b/db/queries/session.sql @@ -28,6 +28,11 @@ SELECT * FROM user_sessions WHERE id = $1; +-- name: ListSessionsByUser :many +SELECT * +FROM user_sessions +WHERE user_id = $1; + -- name: UpdateSessionActivity :exec UPDATE user_sessions SET last_active = NOW() diff --git a/db/queries/user.sql b/db/queries/user.sql index b9e984e..bc3fdd8 100644 --- a/db/queries/user.sql +++ b/db/queries/user.sql @@ -20,6 +20,22 @@ SET hashed_password = $1, WHERE id = $2 RETURNING *; +-- name: UpdateUser :one +UPDATE users +SET email = COALESCE($1, email), + display_name = COALESCE($2, display_name), + blocked = COALESCE($3, blocked), + updated_at = NOW() +WHERE id = $4 +RETURNING *; + +-- name: SetUserBlocked :one +UPDATE users +SET blocked = $1, + updated_at = NOW() +WHERE id = $2 +RETURNING *; + -- name: DeleteUser :exec DELETE FROM users WHERE id = $1; diff --git a/migrations/sql/down/add-blocked-to-users.sql b/migrations/sql/down/add-blocked-to-users.sql new file mode 100644 index 0000000..d9bcfd4 --- /dev/null +++ b/migrations/sql/down/add-blocked-to-users.sql @@ -0,0 +1,2 @@ +ALTER TABLE users +DROP COLUMN blocked; diff --git a/migrations/sql/up/add-blocked-to-users.sql b/migrations/sql/up/add-blocked-to-users.sql new file mode 100644 index 0000000..c35e6fd --- /dev/null +++ b/migrations/sql/up/add-blocked-to-users.sql @@ -0,0 +1,2 @@ +ALTER TABLE users +ADD COLUMN blocked BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/migrations/versions/5b6615c9ab1d_merge_heads.py b/migrations/versions/5b6615c9ab1d_merge_heads.py new file mode 100644 index 0000000..cea5228 --- /dev/null +++ b/migrations/versions/5b6615c9ab1d_merge_heads.py @@ -0,0 +1,28 @@ +"""merge_heads + +Revision ID: 5b6615c9ab1d +Revises: 9f1c3c6e9c1a, c3b8d0f1e2a4 +Create Date: 2026-03-20 02:33:56.591359 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '5b6615c9ab1d' +down_revision: Union[str, Sequence[str], None] = ('9f1c3c6e9c1a', 'c3b8d0f1e2a4') +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + pass + + +def downgrade() -> None: + """Downgrade schema.""" + pass diff --git a/migrations/versions/9f1c3c6e9c1a_add_blocked_to_users.py b/migrations/versions/9f1c3c6e9c1a_add_blocked_to_users.py new file mode 100644 index 0000000..21b14d1 --- /dev/null +++ b/migrations/versions/9f1c3c6e9c1a_add_blocked_to_users.py @@ -0,0 +1,25 @@ +"""add-blocked-to-users + +Revision ID: 9f1c3c6e9c1a +Revises: 5ead72a95638 +Create Date: 2026-03-20 12:50:00.000000 + +""" +from typing import Sequence, Union + +from migrations.helper import run_sql_down, run_sql_up + + +# revision identifiers, used by Alembic. +revision: str = "9f1c3c6e9c1a" +down_revision: Union[str, Sequence[str], None] = "5ead72a95638" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + run_sql_up("add-blocked-to-users") + + +def downgrade() -> None: + run_sql_down("add-blocked-to-users")