# File: backend/app/services/session_service.py (Update with DB)
# Description: 管理会话数据的服务 (使用 SQLAlchemy) — session data management service backed by SQLAlchemy.

import logging
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy import delete as sqlalchemy_delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db.models import SessionModel, AssistantModel # Import DB models
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
# Import ChatService for title generation (consider refactoring later)
from app.services.chat_service import ChatService
import app.core.config as Config

# Module-level ChatService client, constructed once when this module is imported,
# using Config.GOOGLE_API_KEY. SessionService calls its generate_text() to produce
# session titles.
# NOTE(review): if ChatService construction raises (e.g. missing/invalid key),
# importing this module fails — confirm that is the intended behavior.
chat_service_instance = ChatService(Config.GOOGLE_API_KEY)


logger = logging.getLogger(__name__)

# Number of leading characters of the first message quoted in the fallback title.
_FALLBACK_SNIPPET_LEN = 15


class SessionService:
    """Session CRUD and title-generation service (database-backed).

    Every method operates on the caller-supplied ``AsyncSession``. Writes are
    flushed but never committed here; presumably the caller owns the commit.
    """

    @staticmethod
    def _fallback_title(first_message: str) -> str:
        """Build a deterministic title from the start of the first message.

        An ellipsis is appended only when the message was actually truncated.
        """
        snippet = first_message[:_FALLBACK_SNIPPET_LEN]
        if len(first_message) > _FALLBACK_SNIPPET_LEN:
            snippet += "..."
        return f"关于 \"{snippet}\""

    async def _generate_title(self, first_message: str) -> str:
        """Ask the LLM for a short title, falling back to a snippet-based title.

        Title generation is best-effort: an exception from the LLM call, or a
        reply that is not a non-blank string, yields the fallback title rather
        than failing session creation.
        """
        title_prompt = f"根据以下用户第一条消息,为此对话生成一个简洁的标题(不超过10个字):\n\n{first_message}"
        try:
            raw_title = await chat_service_instance.generate_text(title_prompt)
        except Exception:
            # Deliberately broad: the LLM client's exception types are not known
            # here, and a title failure must not block session creation.
            logger.exception("生成会话标题时出错")
            return self._fallback_title(first_message)

        # Guard against None / empty / whitespace-only replies, which would
        # otherwise be stored verbatim as the session title.
        if isinstance(raw_title, str) and raw_title.strip():
            return raw_title.strip()
        logger.warning("生成的会话标题为空,使用默认标题")
        return self._fallback_title(first_message)

    async def create_session(self, db: AsyncSession, session_data: SessionCreateRequest) -> SessionCreateResponse:
        """Create a session for an existing assistant, with an LLM-generated title.

        Args:
            db: Active async DB session; the new row is flushed, not committed.
            session_data: Request carrying ``assistant_id`` and ``first_message``.

        Returns:
            SessionCreateResponse describing the persisted session.

        Raises:
            ValueError: If no assistant with ``session_data.assistant_id`` exists.
        """
        # Check the assistant exists before spending an LLM call on the title.
        result = await db.execute(select(AssistantModel).filter(AssistantModel.id == session_data.assistant_id))
        if result.scalars().first() is None:
            raise ValueError("指定的助手不存在")

        generated_title = await self._generate_title(session_data.first_message)

        # id and created_at are filled in by model/database defaults.
        db_session = SessionModel(title=generated_title, assistant_id=session_data.assistant_id)
        db.add(db_session)
        await db.flush()
        # Reload defaulted columns (presumably id/created_at) from the database.
        await db.refresh(db_session)
        logger.info("会话已创建 (DB): %s", db_session.id)

        return SessionCreateResponse(
            id=db_session.id,
            title=db_session.title,
            assistant_id=db_session.assistant_id,
            created_at=db_session.created_at.isoformat(),  # datetime from the DB model
        )

    async def get_sessions_by_assistant(self, db: AsyncSession, assistant_id: str) -> List[SessionRead]:
        """Return all sessions belonging to an assistant, most recently updated first."""
        stmt = (
            select(SessionModel)
            .filter(SessionModel.assistant_id == assistant_id)
            # created_at breaks ties so the order is deterministic when updated_at matches.
            .order_by(SessionModel.updated_at.desc(), SessionModel.created_at.desc())
        )
        result = await db.execute(stmt)
        return [SessionRead.model_validate(s) for s in result.scalars().all()]

    async def get_session(self, db: AsyncSession, session_id: str) -> Optional[SessionRead]:
        """Return a single session by id, or ``None`` if it does not exist."""
        result = await db.execute(select(SessionModel).filter(SessionModel.id == session_id))
        session = result.scalars().first()
        return SessionRead.model_validate(session) if session is not None else None

    async def delete_session(self, db: AsyncSession, session_id: str) -> bool:
        """Delete a session by id.

        Returns:
            True if the DELETE reported at least one affected row, else False.
        """
        stmt = sqlalchemy_delete(SessionModel).where(SessionModel.id == session_id)
        result = await db.execute(stmt)
        if not result.rowcount > 0:
            return False

        await db.flush()
        logger.info("会话已删除 (DB): %s", session_id)
        # NOTE(review): this is a Core bulk DELETE, which bypasses ORM relationship
        # cascades. Child messages are removed only if the database itself enforces
        # ON DELETE CASCADE (on SQLite that also needs PRAGMA foreign_keys=ON) — confirm.
        return True