cherry-ai/backend/app/services/chat_service.py

128 lines
5.4 KiB
Python

# File: backend/app/services/chat_service.py (更新)
# Description: 封装 LangChain 聊天逻辑 (支持助手配置和会话历史)
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from typing import Dict, List, Optional
from app.services.assistant_service import assistant_service_instance # 获取助手配置
from app.models.pydantic_models import AssistantRead # 引入助手模型
import app.core.config as Config
# --- 更新内存管理 ---
# 使用字典存储不同会话的内存
# key: session_id (str), value: List[BaseMessage]
chat_history_db: Dict[str, List[BaseMessage]] = {}
class ChatService:
"""处理 AI 聊天交互的服务 (支持助手配置)"""
def __init__(self, default_api_key: str):
"""初始化时可传入默认 API Key"""
self.default_api_key = default_api_key
# 不再在 init 中创建固定的 LLM 和 chain
def _get_llm(self, assistant: AssistantRead) -> ChatOpenAI:
"""根据助手配置动态创建 LLM 实例"""
# TODO: 支持不同模型提供商 (Gemini, Anthropic etc.)
if assistant.model.startswith("gpt"):
return ChatOpenAI(
model=assistant.model,
api_key=self.default_api_key, # 或从助手配置中读取特定 key
temperature=assistant.temperature
)
elif assistant.model.startswith("gemini"):
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=assistant.model,
api_key=self.default_api_key, # 或从助手配置中读取特定 key
temperature=assistant.temperature
)
else:
# 默认或抛出错误
print(f"警告: 模型 {assistant.model} 未明确支持,尝试使用 ChatOpenAI")
return ChatOpenAI(
model=assistant.model,
api_key=self.default_api_key,
temperature=assistant.temperature
)
async def get_ai_reply(self, user_message: str, session_id: str, assistant_id: str) -> str:
"""
获取 AI 对用户消息的回复 (使用指定助手和会话历史)
Args:
user_message (str): 用户发送的消息
session_id (str): 会话 ID
assistant_id (str): 使用的助手 ID
Returns:
str: AI 的回复文本
Raises:
ValueError: 如果找不到指定的助手
Exception: 如果调用 AI 服务时发生错误
"""
# 1. 获取助手配置
assistant = assistant_service_instance.get_assistant(assistant_id)
if not assistant:
raise ValueError(f"找不到助手 ID: {assistant_id}")
# 2. 获取或初始化当前会话的历史记录
current_chat_history = chat_history_db.get(session_id, [])
# 3. 构建 Prompt (包含动态系统提示)
prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=assistant.system_prompt), # 使用助手的系统提示
MessagesPlaceholder(variable_name="chat_history"),
HumanMessage(content="{input}"),
])
# 4. 获取 LLM 实例
llm = self._get_llm(assistant)
# 5. 定义输出解析器
output_parser = StrOutputParser()
# 6. 构建 LCEL 链
chain = prompt | llm | output_parser
try:
# 7. 调用链获取回复
ai_response = await chain.ainvoke({
"input": user_message,
"chat_history": current_chat_history,
})
# 8. 更新会话历史记录
current_chat_history.append(HumanMessage(content=user_message))
current_chat_history.append(AIMessage(content=ai_response))
# 限制历史记录长度 (例如最近 10 条消息)
max_history_length = 10
if len(current_chat_history) > max_history_length:
chat_history_db[session_id] = current_chat_history[-max_history_length:]
else:
chat_history_db[session_id] = current_chat_history
return ai_response
except Exception as e:
print(f"调用 LangChain 时出错 (助手: {assistant_id}, 会话: {session_id}): {e}")
raise Exception(f"AI 服务暂时不可用: {e}")
# (可选) 添加一个简单的文本生成方法用于生成标题
async def generate_text(self, prompt_text: str, model_name: str = "gpt-3.5-turbo", temperature: float = 0.5) -> str:
"""使用指定模型生成文本 (用于标题等)"""
try:
# 使用一个临时的、可能更便宜的模型
temp_llm = ChatOpenAI(model=model_name, api_key=self.default_api_key, temperature=temperature)
response = await temp_llm.ainvoke(prompt_text)
return response.content
except Exception as e:
print(f"生成文本时出错: {e}")
return "无法生成标题" # 返回默认值或抛出异常
# --- 创建 ChatService 实例 ---
if not Config.GOOGLE_API_KEY:
raise ValueError("请在 .env 文件中设置 OPENAI_API_KEY")
chat_service_instance = ChatService(default_api_key=Config.GOOGLE_API_KEY)