cherry-ai/backend/app/services/session_service.py
2025-04-30 04:39:36 +08:00

74 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# File: backend/app/services/session_service.py (Update with DB)
# Description: 管理会话数据的服务 (使用 SQLAlchemy)
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import delete as sqlalchemy_delete
from app.db.models import SessionModel, AssistantModel # Import DB models
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
from datetime import datetime, timezone
# Import ChatService for title generation (consider refactoring later)
from app.services.chat_service import ChatService
import app.core.config as Config
# Module-level ChatService instance, constructed at import time from the
# configured Google API key. SessionService uses it to generate session titles.
chat_service_instance = ChatService(Config.GOOGLE_API_KEY)
class SessionService:
    """CRUD operations and title generation for chat sessions (database-backed).

    Every method receives the ``AsyncSession`` to operate on. Pending changes
    are flushed but never committed in this class; committing is left to the
    caller.
    """

    # Number of leading characters of the first message used in the fallback title.
    _FALLBACK_SNIPPET_LEN = 15

    @classmethod
    def _fallback_title(cls, first_message: Optional[str]) -> str:
        """Build a title from the start of the first message (None-safe)."""
        snippet = (first_message or "")[:cls._FALLBACK_SNIPPET_LEN]
        return f"关于 \"{snippet}...\""

    async def _generate_title(self, first_message: Optional[str]) -> str:
        """Ask the LLM for a short title, falling back to a message snippet.

        The fallback is used when the LLM call raises, or when it returns
        something that is not a non-blank string.
        """
        title_prompt = f"根据以下用户第一条消息为此对话生成一个简洁的标题不超过10个字:\n\n{first_message}"
        try:
            raw_title = await chat_service_instance.generate_text(title_prompt)
        except Exception as e:
            # Deliberately best-effort: a failing title generation must not
            # prevent the session from being created.
            print(f"生成会话标题时出错: {e}")
            return self._fallback_title(first_message)

        # LLM output is frequently padded with newlines/spaces and may be empty;
        # never store a blank title.
        cleaned = raw_title.strip() if isinstance(raw_title, str) else ""
        return cleaned or self._fallback_title(first_message)

    async def create_session(self, db: AsyncSession, session_data: SessionCreateRequest) -> SessionCreateResponse:
        """Create a new session for an existing assistant and generate its title.

        Args:
            db: Async database session used for the lookup and the insert.
            session_data: Request carrying ``assistant_id`` and ``first_message``.

        Returns:
            The created session's id, title, assistant id and ISO-formatted
            creation timestamp.

        Raises:
            ValueError: If no assistant with ``session_data.assistant_id`` exists.
        """
        # Verify the assistant exists before spending an LLM call on the title.
        lookup = await db.execute(
            select(AssistantModel).where(AssistantModel.id == session_data.assistant_id)
        )
        if lookup.scalars().first() is None:
            raise ValueError("指定的助手不存在")

        title = await self._generate_title(session_data.first_message)

        # id and created_at are not set here; per the original author's note
        # they come from defaults on the model.
        db_session = SessionModel(
            title=title,
            assistant_id=session_data.assistant_id,
        )
        db.add(db_session)
        await db.flush()  # send the INSERT to the database
        await db.refresh(db_session)  # reload the row so id / created_at are populated
        print(f"会话已创建 (DB): {db_session.id}")

        return SessionCreateResponse(
            id=db_session.id,
            title=db_session.title,
            assistant_id=db_session.assistant_id,
            created_at=db_session.created_at.isoformat(),
        )

    async def get_sessions_by_assistant(self, db: AsyncSession, assistant_id: str) -> List[SessionRead]:
        """Return all sessions of one assistant, most recently updated first.

        Args:
            db: Async database session to query with.
            assistant_id: Id of the assistant whose sessions are listed.

        Returns:
            A list of ``SessionRead`` objects (empty when there are no sessions).
        """
        stmt = (
            select(SessionModel)
            .where(SessionModel.assistant_id == assistant_id)
            .order_by(SessionModel.updated_at.desc())
        )
        result = await db.execute(stmt)
        return [SessionRead.model_validate(row) for row in result.scalars().all()]

    async def get_session(self, db: AsyncSession, session_id: str) -> Optional[SessionRead]:
        """Return a single session by id, or ``None`` when it does not exist.

        Args:
            db: Async database session to query with.
            session_id: Id of the session to fetch.
        """
        result = await db.execute(select(SessionModel).where(SessionModel.id == session_id))
        session = result.scalars().first()
        if session is None:
            return None
        return SessionRead.model_validate(session)

    async def delete_session(self, db: AsyncSession, session_id: str) -> bool:
        """Delete a session by id.

        Args:
            db: Async database session to execute the delete with.
            session_id: Id of the session to delete.

        Returns:
            ``True`` if a row was deleted, ``False`` if no session matched.
        """
        stmt = sqlalchemy_delete(SessionModel).where(SessionModel.id == session_id)
        result = await db.execute(stmt)
        if result.rowcount <= 0:
            return False

        await db.flush()
        print(f"会话已删除 (DB): {session_id}")
        # NOTE(review): this is a statement-level DELETE. SQLAlchemy does not run
        # relationship cascades in Python for such statements, so removal of
        # dependent rows (the original comment mentions messages) relies on
        # ON DELETE CASCADE being configured in the database schema -- confirm.
        return True