# File: backend/app/services/chat_service.py (Update with DB for history)
# Description: Encapsulates the LangChain chat logic (messages are stored in and retrieved from the database)
from typing import List, Tuple

from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from sqlalchemy import select  # sqlalchemy.future is unnecessary on SQLAlchemy 1.4+
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import MessageModel  # Import DB model for persisted messages
from app.services.assistant_service import AssistantService  # Use class directly
from app.models.pydantic_models import AssistantRead
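
# The queries below assume a MessageModel roughly shaped like the following.
# This is a hypothetical sketch for orientation only; the real model lives in
# app/db/models.py:
#
#     class MessageModel(Base):
#         __tablename__ = "messages"
#         id = Column(Integer, primary_key=True)
#         session_id = Column(String, index=True)  # chat session this message belongs to
#         sender = Column(String)                  # 'user' | 'ai' (| 'system')
#         text = Column(Text)
#         order = Column(Integer)                  # 1-based position within the session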


class ChatService:
    """Service handling AI chat interactions (database-backed message history)."""

    def __init__(self, default_api_key: str):
        self.default_api_key = default_api_key
        self.assistant_service = AssistantService()  # Instantiate assistant service

    def _get_llm(self, assistant: AssistantRead) -> BaseChatModel:
        # ... (remains the same) ...
        if assistant.model.startswith("gpt"):
            return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)
        elif assistant.model.startswith("gemini"):
            return ChatGoogleGenerativeAI(
                model=assistant.model,
                api_key=self.default_api_key,  # or read an assistant-specific key from its config
                temperature=assistant.temperature
            )
        else:
            print(f"Warning: model {assistant.model} is not explicitly supported; trying ChatOpenAI")
            return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)

    async def _get_chat_history(self, db: AsyncSession, session_id: str, limit: int = 10) -> Tuple[List[BaseMessage], int]:
        """Load a session's message history from the database (sorted by `order`)."""
        stmt = (
            select(MessageModel)
            .filter(MessageModel.session_id == session_id)
            .order_by(MessageModel.order.desc())  # Get latest first
            .limit(limit)
        )
        result = await db.execute(stmt)
        db_messages = result.scalars().all()
        # Convert to LangChain messages (in correct order: oldest to newest)
        history: List[BaseMessage] = []
        max_order = 0
        for msg in reversed(db_messages):  # Reverse to get chronological order
            if msg.sender == 'user':
                history.append(HumanMessage(content=msg.text))
            elif msg.sender == 'ai':
                history.append(AIMessage(content=msg.text))
            # Add handling for 'system' if needed
            max_order = max(max_order, msg.order)  # Keep track of the latest order number
        return history, max_order
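
    # Note: ORDER BY `order` DESC + LIMIT keeps only the most recent `limit`
    # messages, and reversing restores chronological order. For example, with
    # 12 stored messages and limit=10, messages 3..12 come back oldest-first
    # and max_order == 12.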

    async def _save_message(self, db: AsyncSession, session_id: str, sender: str, text: str, order: int):
        """Persist a single message to the database."""
        db_message = MessageModel(
            session_id=session_id,
            sender=sender,
            text=text,
            order=order
        )
        db.add(db_message)
        await db.flush()  # Ensure it's added before a potential commit
        print(f"Message saved (DB): Session={session_id}, Order={order}, Sender={sender}")

    async def get_ai_reply(self, db: AsyncSession, user_message: str, session_id: str, assistant_id: str) -> str:
        """Get the AI reply, persisting both the user message and the AI reply to the database."""
        # 1. Fetch the assistant configuration
        assistant = await self.assistant_service.get_assistant(db, assistant_id)
        if not assistant:
            raise ValueError(f"Assistant not found: {assistant_id}")
        # 2. Fetch the history and derive the next order numbers
        current_chat_history, last_order = await self._get_chat_history(db, session_id)
        user_message_order = last_order + 1
        ai_message_order = last_order + 2
        # 3. Build the prompt. Using a "{input}" template slot (rather than a
        # hard-coded HumanMessage) lets the "input" key passed to ainvoke()
        # below actually take effect.
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=assistant.system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])
        # 4. Get the LLM and assemble the chain
        llm = self._get_llm(assistant)
        output_parser = StrOutputParser()
        chain = prompt | llm | output_parser
        try:
            # --- Save the user message BEFORE calling the LLM ---
            await self._save_message(db, session_id, 'user', user_message, user_message_order)
            # 5. Invoke the chain to get the reply
            ai_response = await chain.ainvoke({
                "input": user_message,
                "chat_history": current_chat_history,  # Pass history fetched from the DB
            })
            # --- Save the AI response AFTER getting it ---
            await self._save_message(db, session_id, 'ai', ai_response, ai_message_order)
            # Note: history no longer needs to be managed in memory (chat_history_db removed)
            return ai_response
        except Exception as e:
            # Consider rolling back the user-message save if the LLM call fails,
            # although it is often better to keep the user message.
            # await db.rollback()  # Handled by the get_db_session dependency on error
            print(f"LangChain call failed (assistant: {assistant_id}, session: {session_id}): {e}")
            raise RuntimeError(f"AI service is temporarily unavailable: {e}") from e

    async def generate_text(self, prompt_text: str, model_name: str = "gemini-2.0-flash", temperature: float = 0.5) -> str:
        # ... (remains the same) ...
        try:
            temp_llm = ChatGoogleGenerativeAI(
                model=model_name,
                api_key=self.default_api_key,  # or read an assistant-specific key from its config
                temperature=temperature
            )
            response = await temp_llm.ainvoke(prompt_text)
            return response.content
        except Exception as e:
            print(f"Error while generating text: {e}")
            return "Unable to generate a title"

# ChatService instance is now created where needed or injected, no global instance here.
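#
# A minimal wiring sketch (hypothetical; `get_db_session`, `get_chat_service`,
# and `settings.OPENAI_API_KEY` are assumed names, not necessarily this
# project's actual identifiers):
#
#     from fastapi import APIRouter, Depends
#     from app.core.config import settings
#     from app.db.session import get_db_session
#
#     router = APIRouter()
#
#     def get_chat_service() -> ChatService:
#         return ChatService(default_api_key=settings.OPENAI_API_KEY)
#
#     @router.post("/sessions/{session_id}/messages")
#     async def post_message(
#         session_id: str,
#         body: dict,
#         db: AsyncSession = Depends(get_db_session),
#         chat_service: ChatService = Depends(get_chat_service),
#     ):
#         reply = await chat_service.get_ai_reply(db, body["text"], session_id, body["assistant_id"])
#         return {"reply": reply}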