# File: backend/app/services/session_service.py (Update with DB)
# Description: Service that manages session data (using SQLAlchemy)
import logging
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy import delete as sqlalchemy_delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db.models import SessionModel, AssistantModel  # Import DB models
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
# Import ChatService for title generation (consider refactoring later)
from app.services.chat_service import ChatService
import app.core.config as Config

logger = logging.getLogger(__name__)

chat_service_instance = ChatService(Config.GOOGLE_API_KEY)


class SessionService:
    """CRUD and title-generation service for sessions (database-backed).

    No method here commits; they only execute statements and flush, so the
    transaction boundary is left to whoever owns the ``AsyncSession``.
    """

    async def _generate_title(self, first_message: str) -> str:
        """Return a title for a conversation that starts with *first_message*.

        Asks the module-level ``chat_service_instance`` for a short title.
        Title generation is best-effort: if the call raises, or yields
        nothing usable (not a string, empty, or whitespace only), a fallback
        title built from the first 15 characters of the message is returned.
        """
        try:
            title_prompt = f"根据以下用户第一条消息,为此对话生成一个简洁的标题(不超过10个字):\n\n{first_message}"
            generated_title = await chat_service_instance.generate_text(title_prompt)
        except Exception as exc:
            # Deliberately broad: a failed title must never block session
            # creation. Log with traceback so the failure stays diagnosable.
            logger.exception("生成会话标题时出错: %s", exc)
            generated_title = None

        # Models frequently pad their answer with whitespace/newlines; a blank
        # or non-string result would otherwise be stored as the title.
        title = generated_title.strip() if isinstance(generated_title, str) else ""
        if not title:
            title = f"关于 \"{first_message[:15]}...\""  # Fallback
        return title

    async def create_session(self, db: AsyncSession, session_data: SessionCreateRequest) -> SessionCreateResponse:
        """Create a new session and generate its title.

        Args:
            db: Async database session used for all statements.
            session_data: Request carrying ``assistant_id`` and
                ``first_message``.

        Returns:
            The created session's id, title, assistant id and ISO-formatted
            creation time.

        Raises:
            ValueError: If no assistant with ``session_data.assistant_id``
                exists.
        """
        # Make sure the assistant exists before spending an LLM call.
        result = await db.execute(
            select(AssistantModel).filter(AssistantModel.id == session_data.assistant_id)
        )
        assistant = result.scalars().first()
        if assistant is None:
            raise ValueError("指定的助手不存在")

        generated_title = await self._generate_title(session_data.first_message)

        db_session = SessionModel(
            title=generated_title,
            assistant_id=session_data.assistant_id,
            # ID and created_at have defaults
        )
        db.add(db_session)
        # Flush to emit the INSERT, then refresh so the attributes read below
        # (id, created_at) are loaded from the database row.
        await db.flush()
        await db.refresh(db_session)
        logger.info("会话已创建 (DB): %s", db_session.id)

        return SessionCreateResponse(
            id=db_session.id,
            title=db_session.title,
            assistant_id=db_session.assistant_id,
            created_at=db_session.created_at.isoformat(),  # Use datetime from DB model
        )

    async def get_sessions_by_assistant(self, db: AsyncSession, assistant_id: str) -> List[SessionRead]:
        """Return all sessions of the given assistant, most recently updated first."""
        stmt = (
            select(SessionModel)
            .filter(SessionModel.assistant_id == assistant_id)
            .order_by(SessionModel.updated_at.desc())  # Order by update time
        )
        result = await db.execute(stmt)
        sessions = result.scalars().all()
        return [SessionRead.model_validate(s) for s in sessions]

    async def get_session(self, db: AsyncSession, session_id: str) -> Optional[SessionRead]:
        """Return the session with *session_id*, or ``None`` if it does not exist."""
        result = await db.execute(select(SessionModel).filter(SessionModel.id == session_id))
        session = result.scalars().first()
        return SessionRead.model_validate(session) if session else None

    async def delete_session(self, db: AsyncSession, session_id: str) -> bool:
        """Delete the session with *session_id*.

        Returns:
            ``True`` if a row was deleted, ``False`` if no such session
            existed.
        """
        # NOTE(review): this is a bulk DELETE statement, which does not run
        # ORM relationship cascades. Removal of the session's messages
        # therefore presumably relies on a database-level ON DELETE CASCADE
        # foreign key -- confirm against the models / DB configuration.
        stmt = sqlalchemy_delete(SessionModel).where(SessionModel.id == session_id)
        result = await db.execute(stmt)
        if result.rowcount > 0:
            await db.flush()
            logger.info("会话已删除 (DB): %s", session_id)
            # Deletion of messages handled by cascade
            return True
        return False