93 lines
4.0 KiB
Python
93 lines
4.0 KiB
Python
# File: backend/app/services/chat_service.py (新建)
|
||
# Description: 封装 LangChain 聊天逻辑
|
||
|
||
from langchain_openai import ChatOpenAI
|
||
from langchain_google_genai import ChatGoogleGenerativeAI # 如果使用 Google
|
||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||
from langchain_core.output_parsers import StrOutputParser
|
||
from langchain_core.messages import HumanMessage, AIMessage
|
||
|
||
# --- 可选:添加内存管理 ---
|
||
# 简单的内存实现 (可以替换为更复杂的 LangChain Memory 类)
|
||
chat_history = {} # 使用字典存储不同会话的内存,需要 session_id
|
||
|
||
class ChatService:
|
||
"""处理 AI 聊天交互的服务"""
|
||
|
||
def __init__(self, api_key: str):
|
||
"""
|
||
初始化 ChatService
|
||
Args:
|
||
api_key (str): 用于 LLM 的 API 密钥
|
||
"""
|
||
# --- 选择并初始化 LLM ---
|
||
# 使用 OpenAI GPT-3.5 Turbo (推荐) 或 GPT-4
|
||
# self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=api_key, temperature=0.7)
|
||
# self.llm = ChatOpenAI(model="gpt-4", api_key=api_key, temperature=0.7)
|
||
|
||
# --- 如果使用 Google Gemini ---
|
||
self.llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=api_key, convert_system_message_to_human=True)
|
||
|
||
# --- 定义 Prompt 模板 ---
|
||
# 包含系统消息、历史记录占位符和当前用户输入
|
||
self.prompt = ChatPromptTemplate.from_messages([
|
||
("system", "你是一个乐于助人的 AI 助手,请用简洁明了的语言回答问题。你的名字叫 CherryAI。"),
|
||
MessagesPlaceholder(variable_name="chat_history"), # 用于插入历史消息
|
||
("human", "{input}"), # 用户当前输入
|
||
])
|
||
|
||
# --- 定义输出解析器 ---
|
||
# 将 LLM 的输出解析为字符串
|
||
self.output_parser = StrOutputParser()
|
||
|
||
# --- 构建 LangChain Expression Language (LCEL) 链 ---
|
||
self.chain = self.prompt | self.llm | self.output_parser
|
||
|
||
async def get_ai_reply(self, user_message: str, session_id: str = "default_session") -> str:
|
||
"""
|
||
获取 AI 对用户消息的回复 (异步)
|
||
Args:
|
||
user_message (str): 用户发送的消息
|
||
session_id (str): (可选) 用于区分不同对话的会话 ID,以支持内存
|
||
Returns:
|
||
str: AI 的回复文本
|
||
Raises:
|
||
Exception: 如果调用 AI 服务时发生错误
|
||
"""
|
||
try:
|
||
# --- 获取当前会话的历史记录 (如果需要内存) ---
|
||
current_chat_history = chat_history.get(session_id, [])
|
||
|
||
# --- 使用 ainvoke 进行异步调用 ---
|
||
ai_response = await self.chain.ainvoke({
|
||
"input": user_message,
|
||
"chat_history": current_chat_history, # 传入历史记录
|
||
})
|
||
|
||
# --- 更新会话历史记录 (如果需要内存) ---
|
||
# 只保留最近 N 轮对话,防止历史过长
|
||
max_history_length = 10 # 保留最近 5 轮对话 (10条消息)
|
||
current_chat_history.append(HumanMessage(content=user_message))
|
||
current_chat_history.append(AIMessage(content=ai_response))
|
||
# 如果历史记录超过长度,移除最早的消息
|
||
if len(current_chat_history) > max_history_length:
|
||
chat_history[session_id] = current_chat_history[-max_history_length:]
|
||
else:
|
||
chat_history[session_id] = current_chat_history
|
||
|
||
return ai_response
|
||
|
||
except Exception as e:
|
||
print(f"调用 LangChain 时出错: {e}")
|
||
# 可以进行更细致的错误处理,例如区分 API 错误和内部错误
|
||
raise Exception(f"AI 服务暂时不可用: {e}")
|
||
|
||
|
||
# --- 在文件末尾,创建 ChatService 的实例 ---
|
||
# 从配置中获取 API Key
|
||
from app.core.config import GOOGLE_API_KEY
|
||
if not GOOGLE_API_KEY:
|
||
raise ValueError("请在 .env 文件中设置 GOOGLE_API_KEY")
|
||
|
||
chat_service_instance = ChatService(api_key=GOOGLE_API_KEY)
|