cherry-ai/backend/app/services/session_service.py

83 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# File: backend/app/services/session_service.py (新建)
# Description: 管理会话数据的服务 (内存实现)
from typing import Dict, List, Optional
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse, AssistantRead
from app.services.assistant_service import assistant_service_instance # 需要获取助手信息
from datetime import datetime, timezone
import uuid
# 导入 ChatService 以调用 LLM 生成标题 (避免循环导入,考虑重构)
# from app.services.chat_service import chat_service_instance
# 内存数据库存储会话
# key: session_id (str), value: SessionRead object
sessions_db: Dict[str, SessionRead] = {}
class SessionService:
"""会话数据的 CRUD 及标题生成服务"""
async def create_session(self, session_data: SessionCreateRequest) -> SessionCreateResponse:
"""创建新会话并生成标题"""
assistant = assistant_service_instance.get_assistant(session_data.assistant_id)
if not assistant:
raise ValueError("指定的助手不存在")
new_id = f"session-{uuid.uuid4()}"
created_time = datetime.now(timezone.utc)
# --- TODO: 调用 LLM 生成标题 ---
# title_prompt = f"根据以下用户第一条消息为此对话生成一个简洁的标题不超过10个字:\n\n{session_data.first_message}"
# generated_title = await chat_service_instance.generate_text(title_prompt) # 需要一个简单的文本生成方法
# 模拟标题生成
generated_title = f"关于 \"{session_data.first_message[:15]}...\""
print(f"为新会话 {new_id} 生成标题: {generated_title}")
# --- 模拟结束 ---
new_session = SessionRead(
id=new_id,
title=generated_title,
assistant_id=session_data.assistant_id,
created_at=created_time.isoformat() # 存储 ISO 格式字符串
)
sessions_db[new_id] = new_session
print(f"会话已创建: {new_id}")
return SessionCreateResponse(
id=new_session.id,
title=new_session.title,
assistant_id=new_session.assistant_id,
created_at=new_session.created_at
)
def get_sessions_by_assistant(self, assistant_id: str) -> List[SessionRead]:
"""获取指定助手的所有会话"""
return [s for s in sessions_db.values() if s.assistant_id == assistant_id]
def get_session(self, session_id: str) -> Optional[SessionRead]:
"""获取单个会话"""
return sessions_db.get(session_id)
def delete_session(self, session_id: str) -> bool:
"""删除会话"""
if session_id in sessions_db:
del sessions_db[session_id]
print(f"会话已删除: {session_id}")
# TODO: 删除关联的消息
return True
return False
def delete_sessions_by_assistant(self, assistant_id: str) -> int:
"""删除指定助手的所有会话"""
ids_to_delete = [s.id for s in sessions_db.values() if s.assistant_id == assistant_id]
count = 0
for session_id in ids_to_delete:
if self.delete_session(session_id):
count += 1
print(f"删除了助手 {assistant_id}{count} 个会话")
return count
# 创建服务实例
session_service_instance = SessionService()