# File: backend/app/services/chat_service.py (新建) # Description: 封装 LangChain 聊天逻辑 from langchain_openai import ChatOpenAI from langchain_google_genai import ChatGoogleGenerativeAI # 如果使用 Google from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.output_parsers import StrOutputParser from langchain_core.messages import HumanMessage, AIMessage # --- 可选:添加内存管理 --- # 简单的内存实现 (可以替换为更复杂的 LangChain Memory 类) chat_history = {} # 使用字典存储不同会话的内存,需要 session_id class ChatService: """处理 AI 聊天交互的服务""" def __init__(self, api_key: str): """ 初始化 ChatService Args: api_key (str): 用于 LLM 的 API 密钥 """ # --- 选择并初始化 LLM --- # 使用 OpenAI GPT-3.5 Turbo (推荐) 或 GPT-4 # self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=api_key, temperature=0.7) # self.llm = ChatOpenAI(model="gpt-4", api_key=api_key, temperature=0.7) # --- 如果使用 Google Gemini --- self.llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=api_key, convert_system_message_to_human=True) # --- 定义 Prompt 模板 --- # 包含系统消息、历史记录占位符和当前用户输入 self.prompt = ChatPromptTemplate.from_messages([ ("system", "你是一个乐于助人的 AI 助手,请用简洁明了的语言回答问题。你的名字叫 CherryAI。"), MessagesPlaceholder(variable_name="chat_history"), # 用于插入历史消息 ("human", "{input}"), # 用户当前输入 ]) # --- 定义输出解析器 --- # 将 LLM 的输出解析为字符串 self.output_parser = StrOutputParser() # --- 构建 LangChain Expression Language (LCEL) 链 --- self.chain = self.prompt | self.llm | self.output_parser async def get_ai_reply(self, user_message: str, session_id: str = "default_session") -> str: """ 获取 AI 对用户消息的回复 (异步) Args: user_message (str): 用户发送的消息 session_id (str): (可选) 用于区分不同对话的会话 ID,以支持内存 Returns: str: AI 的回复文本 Raises: Exception: 如果调用 AI 服务时发生错误 """ try: # --- 获取当前会话的历史记录 (如果需要内存) --- current_chat_history = chat_history.get(session_id, []) # --- 使用 ainvoke 进行异步调用 --- ai_response = await self.chain.ainvoke({ "input": user_message, "chat_history": current_chat_history, # 传入历史记录 }) # --- 更新会话历史记录 (如果需要内存) --- # 只保留最近 N 轮对话,防止历史过长 max_history_length = 10 # 保留最近 5 轮对话 (10条消息) current_chat_history.append(HumanMessage(content=user_message)) current_chat_history.append(AIMessage(content=ai_response)) # 如果历史记录超过长度,移除最早的消息 if len(current_chat_history) > max_history_length: chat_history[session_id] = current_chat_history[-max_history_length:] else: chat_history[session_id] = current_chat_history return ai_response except Exception as e: print(f"调用 LangChain 时出错: {e}") # 可以进行更细致的错误处理,例如区分 API 错误和内部错误 raise Exception(f"AI 服务暂时不可用: {e}") # --- 在文件末尾,创建 ChatService 的实例 --- # 从配置中获取 API Key from app.core.config import GOOGLE_API_KEY if not GOOGLE_API_KEY: raise ValueError("请在 .env 文件中设置 GOOGLE_API_KEY") chat_service_instance = ChatService(api_key=GOOGLE_API_KEY)