cherry-ai/backend/app/services/chat_service.py
2025-04-29 18:15:16 +08:00

93 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# File: backend/app/services/chat_service.py (新建)
# Description: 封装 LangChain 聊天逻辑
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI # 如果使用 Google
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage
# --- 可选:添加内存管理 ---
# 简单的内存实现 (可以替换为更复杂的 LangChain Memory 类)
chat_history = {} # 使用字典存储不同会话的内存,需要 session_id
class ChatService:
"""处理 AI 聊天交互的服务"""
def __init__(self, api_key: str):
"""
初始化 ChatService
Args:
api_key (str): 用于 LLM 的 API 密钥
"""
# --- 选择并初始化 LLM ---
# 使用 OpenAI GPT-3.5 Turbo (推荐) 或 GPT-4
# self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=api_key, temperature=0.7)
# self.llm = ChatOpenAI(model="gpt-4", api_key=api_key, temperature=0.7)
# --- 如果使用 Google Gemini ---
self.llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=api_key, convert_system_message_to_human=True)
# --- 定义 Prompt 模板 ---
# 包含系统消息、历史记录占位符和当前用户输入
self.prompt = ChatPromptTemplate.from_messages([
("system", "你是一个乐于助人的 AI 助手,请用简洁明了的语言回答问题。你的名字叫 CherryAI。"),
MessagesPlaceholder(variable_name="chat_history"), # 用于插入历史消息
("human", "{input}"), # 用户当前输入
])
# --- 定义输出解析器 ---
# 将 LLM 的输出解析为字符串
self.output_parser = StrOutputParser()
# --- 构建 LangChain Expression Language (LCEL) 链 ---
self.chain = self.prompt | self.llm | self.output_parser
async def get_ai_reply(self, user_message: str, session_id: str = "default_session") -> str:
"""
获取 AI 对用户消息的回复 (异步)
Args:
user_message (str): 用户发送的消息
session_id (str): (可选) 用于区分不同对话的会话 ID以支持内存
Returns:
str: AI 的回复文本
Raises:
Exception: 如果调用 AI 服务时发生错误
"""
try:
# --- 获取当前会话的历史记录 (如果需要内存) ---
current_chat_history = chat_history.get(session_id, [])
# --- 使用 ainvoke 进行异步调用 ---
ai_response = await self.chain.ainvoke({
"input": user_message,
"chat_history": current_chat_history, # 传入历史记录
})
# --- 更新会话历史记录 (如果需要内存) ---
# 只保留最近 N 轮对话,防止历史过长
max_history_length = 10 # 保留最近 5 轮对话 (10条消息)
current_chat_history.append(HumanMessage(content=user_message))
current_chat_history.append(AIMessage(content=ai_response))
# 如果历史记录超过长度,移除最早的消息
if len(current_chat_history) > max_history_length:
chat_history[session_id] = current_chat_history[-max_history_length:]
else:
chat_history[session_id] = current_chat_history
return ai_response
except Exception as e:
print(f"调用 LangChain 时出错: {e}")
# 可以进行更细致的错误处理,例如区分 API 错误和内部错误
raise Exception(f"AI 服务暂时不可用: {e}")
# --- 在文件末尾,创建 ChatService 的实例 ---
# 从配置中获取 API Key
from app.core.config import GOOGLE_API_KEY
if not GOOGLE_API_KEY:
raise ValueError("请在 .env 文件中设置 GOOGLE_API_KEY")
chat_service_instance = ChatService(api_key=GOOGLE_API_KEY)