# File: backend/app/services/chat_service.py (Update with DB for history)
# Description: Encapsulates the LangChain chat logic (messages are stored in and retrieved from the database)

from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from typing import List, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db.models import MessageModel, AssistantModel  # Import DB models
from app.services.assistant_service import AssistantService  # Use class directly
from app.models.pydantic_models import AssistantRead
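
# For reference, a minimal sketch of the MessageModel this service relies on.
# The real definition lives in app/db/models.py; the fields below are inferred
# from how the model is used in this file and may differ from the actual schema:
#
#     class MessageModel(Base):
#         __tablename__ = "messages"
#         id = Column(Integer, primary_key=True, index=True)
#         session_id = Column(String, index=True, nullable=False)
#         sender = Column(String, nullable=False)    # 'user' or 'ai'
#         text = Column(Text, nullable=False)
#         order = Column(Integer, nullable=False)    # per-session message sequence

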
class ChatService:
    """Service for handling AI chat interactions (database-backed history)."""

    def __init__(self, default_api_key: str):
        self.default_api_key = default_api_key
        self.assistant_service = AssistantService()  # Instantiate assistant service

    def _get_llm(self, assistant: AssistantRead) -> BaseChatModel:
        # ... (remains the same) ...
        if assistant.model.startswith("gpt"):
            return ChatOpenAI(
                model=assistant.model,
                api_key=self.default_api_key,
                temperature=assistant.temperature,
            )
        elif assistant.model.startswith("gemini"):
            return ChatGoogleGenerativeAI(
                model=assistant.model,
                api_key=self.default_api_key,  # or read a model-specific key from the assistant config
                temperature=assistant.temperature,
            )
        else:
            print(f"Warning: model {assistant.model} is not explicitly supported; trying ChatOpenAI")
            return ChatOpenAI(
                model=assistant.model,
                api_key=self.default_api_key,
                temperature=assistant.temperature,
            )

    async def _get_chat_history(self, db: AsyncSession, session_id: str, limit: int = 10) -> Tuple[List[BaseMessage], int]:
        """Load the history for a session from the database (ordered by `order`)."""
        stmt = (
            select(MessageModel)
            .filter(MessageModel.session_id == session_id)
            .order_by(MessageModel.order.desc())  # Get latest first
            .limit(limit)
        )
        result = await db.execute(stmt)
        db_messages = result.scalars().all()

        # Convert to LangChain messages (in chronological order: oldest to newest)
        history: List[BaseMessage] = []
        max_order = 0
        for msg in reversed(db_messages):  # Reverse to get chronological order
            if msg.sender == 'user':
                history.append(HumanMessage(content=msg.text))
            elif msg.sender == 'ai':
                history.append(AIMessage(content=msg.text))
            # Add handling for 'system' if needed
            max_order = max(max_order, msg.order)  # Keep track of the latest order number

        return history, max_order
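
    # If system messages are also persisted, the loop above can map them as
    # well; a sketch, assuming such rows are stored with sender == 'system'
    # (this schema may not do that):
    #
    #     elif msg.sender == 'system':
    #         history.append(SystemMessage(content=msg.text))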

    async def _save_message(self, db: AsyncSession, session_id: str, sender: str, text: str, order: int):
        """Persist a message to the database."""
        db_message = MessageModel(
            session_id=session_id,
            sender=sender,
            text=text,
            order=order
        )
        db.add(db_message)
        await db.flush()  # Ensure it's added before a potential commit
        print(f"Message saved (DB): Session={session_id}, Order={order}, Sender={sender}")

    async def get_ai_reply(self, db: AsyncSession, user_message: str, session_id: str, assistant_id: str) -> str:
        """Get the AI reply, saving both the user message and the AI reply to the database."""
        # 1. Fetch the assistant configuration
        assistant = await self.assistant_service.get_assistant(db, assistant_id)
        if not assistant:
            raise ValueError(f"Assistant ID not found: {assistant_id}")

        # 2. Fetch the history and compute the next sequence numbers
        current_chat_history, last_order = await self._get_chat_history(db, session_id)
        user_message_order = last_order + 1
        ai_message_order = last_order + 2

        # 3. Build the prompt ("{input}" is filled from the ainvoke payload below)
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=assistant.system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])

        # 4. Get the LLM and assemble the chain
        llm = self._get_llm(assistant)
        output_parser = StrOutputParser()
        chain = prompt | llm | output_parser

        try:
            # --- Save the user message BEFORE calling the LLM ---
            await self._save_message(db, session_id, 'user', user_message, user_message_order)

            # 5. Invoke the chain to get a reply
            ai_response = await chain.ainvoke({
                "input": user_message,
                "chat_history": current_chat_history,  # Pass the history fetched from the DB
            })

            # --- Save the AI response AFTER receiving it ---
            await self._save_message(db, session_id, 'ai', ai_response, ai_message_order)

            # Note: history no longer needs to be managed in memory (chat_history_db removed)
            return ai_response

        except Exception as e:
            # Consider rolling back the user message save if the LLM call fails,
            # although it is often better to keep the user message.
            # await db.rollback()  # Handled by the get_db_session dependency on error
            print(f"Error calling LangChain (assistant: {assistant_id}, session: {session_id}): {e}")
            raise Exception(f"AI service temporarily unavailable: {e}") from e
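
    # Example call site for get_ai_reply (a hypothetical FastAPI route; the
    # path, ChatRequest model, and get_chat_service dependency are illustrative,
    # not the project's actual router code):
    #
    #     @router.post("/sessions/{session_id}/messages")
    #     async def post_message(
    #         session_id: str,
    #         body: ChatRequest,
    #         db: AsyncSession = Depends(get_db_session),
    #         chat_service: ChatService = Depends(get_chat_service),
    #     ):
    #         reply = await chat_service.get_ai_reply(db, body.text, session_id, body.assistant_id)
    #         return {"reply": reply}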

    async def generate_text(self, prompt_text: str, model_name: str = "gemini-2.0-flash", temperature: float = 0.5) -> str:
        # ... (remains the same) ...
        try:
            temp_llm = ChatGoogleGenerativeAI(
                model=model_name,
                api_key=self.default_api_key,  # or read a model-specific key from the assistant config
                temperature=temperature
            )
            response = await temp_llm.ainvoke(prompt_text)
            return response.content
        except Exception as e:
            print(f"Error generating text: {e}")
            return "Unable to generate title"
# ChatService instance is now created where needed or injected, no global instance here.
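# A minimal sketch of providing that instance via a FastAPI dependency
# (settings.DEFAULT_API_KEY is a placeholder for however the project loads its key):
#
#     def get_chat_service() -> ChatService:
#         return ChatService(default_api_key=settings.DEFAULT_API_KEY)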