# File: backend/app/services/chat_service.py (Update with DB for history)
# Description: Wraps the LangChain chat logic (messages are stored in and
# loaded from the database).
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from typing import Dict, List, Optional, Tuple, Union
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select

from app.db.models import MessageModel, AssistantModel  # Import DB models
from app.services.assistant_service import AssistantService  # Use class directly
from app.models.pydantic_models import AssistantRead


class ChatService:
    """Service for AI chat interactions, with chat history persisted in the DB.

    ``get_ai_reply`` loads recent history for a session, sends it plus the new
    user message to the assistant's model, and stores both the user message
    and the AI reply as ``MessageModel`` rows.
    """

    def __init__(self, default_api_key: str):
        """Create the service.

        Args:
            default_api_key: API key handed to every LLM client this service
                builds (both the OpenAI and the Gemini clients).
        """
        self.default_api_key = default_api_key
        self.assistant_service = AssistantService()

    def _get_llm(self, assistant: AssistantRead) -> Union[ChatOpenAI, ChatGoogleGenerativeAI]:
        """Build the chat-model client for ``assistant``.

        Routing is by model-name prefix: ``"gpt..."`` -> ``ChatOpenAI``,
        ``"gemini..."`` -> ``ChatGoogleGenerativeAI``. Any other name prints a
        warning and falls back to ``ChatOpenAI``.

        Args:
            assistant: Assistant config; ``model`` and ``temperature`` are read.

        Returns:
            A configured LangChain chat model client.
        """
        if assistant.model.startswith("gpt"):
            return ChatOpenAI(
                model=assistant.model,
                api_key=self.default_api_key,
                temperature=assistant.temperature,
            )
        elif assistant.model.startswith("gemini"):
            return ChatGoogleGenerativeAI(
                model=assistant.model,
                # NOTE(review): the same default key is used for both providers;
                # a provider/assistant-specific key could be read from the
                # assistant config instead.
                api_key=self.default_api_key,
                temperature=assistant.temperature,
            )
        else:
            print(f"警告: 模型 {assistant.model} 未明确支持,尝试使用 ChatOpenAI")
            return ChatOpenAI(
                model=assistant.model,
                api_key=self.default_api_key,
                temperature=assistant.temperature,
            )

    async def _get_chat_history(
        self, db: AsyncSession, session_id: str, limit: int = 10
    ) -> Tuple[List[BaseMessage], int]:
        """Load the ``limit`` most recent messages of a session from the DB.

        Args:
            db: Async SQLAlchemy session.
            session_id: Chat session whose messages are loaded.
            limit: Maximum number of rows to load (newest ones are kept).

        Returns:
            ``(history, max_order)``. ``history`` holds the loaded messages as
            LangChain messages in ascending ``order`` (oldest first).
            ``max_order`` is the highest ``order`` among the loaded rows, or 0
            when the session has no messages. Only ``'user'`` and ``'ai'``
            senders are converted. Rows with other senders are left out of
            ``history`` but still count toward ``max_order``.
        """
        stmt = (
            select(MessageModel)
            .filter(MessageModel.session_id == session_id)
            # Newest first, so LIMIT keeps the most recent messages.
            .order_by(MessageModel.order.desc())
            .limit(limit)
        )
        result = await db.execute(stmt)
        db_messages = result.scalars().all()

        history: List[BaseMessage] = []
        max_order = 0
        # Rows arrive newest-first; reverse so the model sees them chronologically.
        for msg in reversed(db_messages):
            if msg.sender == 'user':
                history.append(HumanMessage(content=msg.text))
            elif msg.sender == 'ai':
                history.append(AIMessage(content=msg.text))
            # TODO: convert 'system' rows to SystemMessage if they are ever stored.
            max_order = max(max_order, msg.order)
        return history, max_order

    async def _save_message(self, db: AsyncSession, session_id: str, sender: str, text: str, order: int):
        """Add one message row to the session and flush it (no commit here).

        Args:
            db: Async SQLAlchemy session.
            session_id: Chat session the message belongs to.
            sender: Sender tag; this service writes ``'user'`` or ``'ai'``.
            text: Message content.
            order: Sequence number of the message within the session.
        """
        db_message = MessageModel(
            session_id=session_id,
            sender=sender,
            text=text,
            order=order,
        )
        db.add(db_message)
        # Flush so the INSERT is issued now; committing is left to the caller.
        await db.flush()
        print(f"消息已保存 (DB): Session={session_id}, Order={order}, Sender={sender}")

    async def get_ai_reply(self, db: AsyncSession, user_message: str, session_id: str, assistant_id: str) -> str:
        """Get an AI reply and persist both the user message and the reply.

        Args:
            db: Async SQLAlchemy session. This method only flushes; it does not
                commit or roll back.
            user_message: The new user input.
            session_id: Chat session to read history from and write to.
            assistant_id: Assistant whose model/system prompt is used.

        Returns:
            The AI reply text.

        Raises:
            ValueError: If no assistant exists for ``assistant_id``.
            Exception: Wraps any failure while saving messages or calling the
                model. The original error is attached as ``__cause__``.
        """
        # 1. Load the assistant configuration.
        assistant = await self.assistant_service.get_assistant(db, assistant_id)
        if not assistant:
            raise ValueError(f"找不到助手 ID: {assistant_id}")

        # 2. Load history and derive the next sequence numbers.
        current_chat_history, last_order = await self._get_chat_history(db, session_id)
        user_message_order = last_order + 1
        ai_message_order = last_order + 2

        # 3. Build the prompt. Message objects (not template strings) are
        # inserted verbatim, so braces in the system prompt or user text are
        # not parsed as template variables. The only template input is
        # "chat_history".
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=assistant.system_prompt),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessage(content=user_message),
        ])

        # 4. Build the LLM chain.
        llm = self._get_llm(assistant)
        output_parser = StrOutputParser()
        chain = prompt | llm | output_parser

        try:
            # Save the user message BEFORE calling the LLM.
            await self._save_message(db, session_id, 'user', user_message, user_message_order)

            # 5. Invoke the chain with the history loaded from the DB.
            ai_response = await chain.ainvoke({"chat_history": current_chat_history})

            # Save the AI response AFTER receiving it.
            await self._save_message(db, session_id, 'ai', ai_response, ai_message_order)
            return ai_response
        except Exception as e:
            # No rollback here. Per the original design note, rolling back on
            # error is left to the caller's DB-session dependency
            # (get_db_session) — TODO confirm.
            print(f"调用 LangChain 时出错 (助手: {assistant_id}, 会话: {session_id}): {e}")
            # Chain the cause so the original traceback survives the re-wrap.
            raise Exception(f"AI 服务暂时不可用: {e}") from e

    async def generate_text(self, prompt_text: str, model_name: str = "gemini-2.0-flash", temperature: float = 0.5) -> str:
        """Run a one-off prompt through a Gemini model (best-effort).

        ``model_name`` is always passed to ``ChatGoogleGenerativeAI``, so it
        must name a Gemini model.

        Args:
            prompt_text: Prompt sent as a single input.
            model_name: Gemini model name.
            temperature: Sampling temperature.

        Returns:
            The model response's ``content``. On any error, the fixed fallback
            string "无法生成标题" ("unable to generate title") is returned.
        """
        try:
            temp_llm = ChatGoogleGenerativeAI(
                model=model_name,
                # Or read a specific key from the assistant config.
                api_key=self.default_api_key,
                temperature=temperature,
            )
            response = await temp_llm.ainvoke(prompt_text)
            # NOTE(review): `content` may not be a plain str for some model
            # outputs — confirm before relying on the `-> str` annotation.
            return response.content
        except Exception as e:
            # Deliberate best-effort: callers get a fallback title instead of an error.
            print(f"生成文本时出错: {e}")
            return "无法生成标题"


# ChatService instances are created where needed or injected; there is no
# global instance here.