# File: backend/app/services/chat_service.py (updated)
# Description: Encapsulates the LangChain chat logic (supports assistant configs and session history)

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from langchain_core.language_models import BaseChatModel
from typing import Dict, List

from app.services.assistant_service import assistant_service_instance  # provides the assistant configs
from app.models.pydantic_models import AssistantRead  # assistant schema
import app.core.config as Config

# --- Updated in-memory history management ---
# A dict holding a separate message list per conversation.
# key: session_id (str), value: List[BaseMessage]
chat_history_db: Dict[str, List[BaseMessage]] = {}

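# Example of what an entry looks like after one exchange (illustrative only;
# the session ID below is a made-up value):
# chat_history_db["session-123"] == [
#     HumanMessage(content="Hello"),
#     AIMessage(content="Hi! How can I help?"),
# ]
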
class ChatService:
    """Service that handles AI chat interactions (supports assistant configs)."""

    def __init__(self, default_api_key: str):
        """A default API key can be supplied at construction time."""
        self.default_api_key = default_api_key
        # The LLM and chain are no longer created once in __init__;
        # they are built per request from the assistant's configuration.

    def _get_llm(self, assistant: AssistantRead) -> BaseChatModel:
        """Create an LLM instance dynamically from the assistant's configuration."""
        # TODO: support more model providers (e.g. Anthropic)
        if assistant.model.startswith("gpt"):
            return ChatOpenAI(
                model=assistant.model,
                api_key=self.default_api_key,  # or read a provider-specific key from the assistant config
                temperature=assistant.temperature,
            )
        elif assistant.model.startswith("gemini"):
            from langchain_google_genai import ChatGoogleGenerativeAI
            return ChatGoogleGenerativeAI(
                model=assistant.model,
                api_key=self.default_api_key,  # note: Gemini models need a Google API key (could come from the assistant config)
                temperature=assistant.temperature,
            )
        else:
            # Fall back to ChatOpenAI (alternatively, raise an error here)
            print(f"Warning: model {assistant.model} is not explicitly supported, falling back to ChatOpenAI")
            return ChatOpenAI(
                model=assistant.model,
                api_key=self.default_api_key,
                temperature=assistant.temperature,
            )

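    # (Sketch) One way the TODO above might be extended: an extra branch per
    # provider. The lines below are illustrative only; `self.anthropic_api_key`
    # is a hypothetical attribute that is not defined in __init__.
    #
    # elif assistant.model.startswith("claude"):
    #     from langchain_anthropic import ChatAnthropic
    #     return ChatAnthropic(
    #         model=assistant.model,
    #         api_key=self.anthropic_api_key,
    #         temperature=assistant.temperature,
    #     )
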
    async def get_ai_reply(self, user_message: str, session_id: str, assistant_id: str) -> str:
        """
        Get the AI reply to a user message (using the given assistant and the session history).

        Args:
            user_message (str): the message sent by the user
            session_id (str): the conversation/session ID
            assistant_id (str): the ID of the assistant to use

        Returns:
            str: the AI's reply text

        Raises:
            ValueError: if the specified assistant cannot be found
            Exception: if the call to the AI service fails
        """
        # 1. Look up the assistant configuration
        assistant = assistant_service_instance.get_assistant(assistant_id)
        if not assistant:
            raise ValueError(f"Assistant not found: {assistant_id}")

        # 2. Get or initialize this session's history
        current_chat_history = chat_history_db.get(session_id, [])

        # 3. Build the prompt (with the assistant's dynamic system prompt).
        #    The ("human", "{input}") tuple is a template, so the variable gets
        #    substituted; a plain HumanMessage would be sent literally.
        prompt = ChatPromptTemplate.from_messages([
            SystemMessage(content=assistant.system_prompt),  # use the assistant's system prompt
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{input}"),
        ])

        # 4. Get the LLM instance
        llm = self._get_llm(assistant)

        # 5. Define the output parser
        output_parser = StrOutputParser()

        # 6. Build the LCEL chain
        chain = prompt | llm | output_parser

        try:
            # 7. Invoke the chain to get the reply
            ai_response = await chain.ainvoke({
                "input": user_message,
                "chat_history": current_chat_history,
            })

            # 8. Update the session history
            current_chat_history.append(HumanMessage(content=user_message))
            current_chat_history.append(AIMessage(content=ai_response))
            # Cap the history length (e.g. keep only the 10 most recent messages)
            max_history_length = 10
            if len(current_chat_history) > max_history_length:
                chat_history_db[session_id] = current_chat_history[-max_history_length:]
            else:
                chat_history_db[session_id] = current_chat_history

            return ai_response

        except Exception as e:
            print(f"LangChain call failed (assistant: {assistant_id}, session: {session_id}): {e}")
            raise Exception(f"AI service is temporarily unavailable: {e}")

    # (Optional) a simple text-generation helper, e.g. for conversation titles
    async def generate_text(self, prompt_text: str, model_name: str = "gpt-3.5-turbo", temperature: float = 0.5) -> str:
        """Generate text with the given model (used for titles and similar short outputs)."""
        try:
            # Use a temporary, possibly cheaper model
            temp_llm = ChatOpenAI(model=model_name, api_key=self.default_api_key, temperature=temperature)
            response = await temp_llm.ainvoke(prompt_text)
            return response.content
        except Exception as e:
            print(f"Text generation failed: {e}")
            return "Unable to generate a title"  # return a default value (or re-raise instead)


# --- Create the ChatService instance ---
if not Config.OPENAI_API_KEY:
    raise ValueError("Please set OPENAI_API_KEY in the .env file")
# ChatOpenAI is the default provider; Gemini models would additionally need a GOOGLE_API_KEY.
chat_service_instance = ChatService(default_api_key=Config.OPENAI_API_KEY)
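
# --- Illustrative usage (sketch only, not part of the service) ---
# A minimal example of how an API route might call this service. The router,
# request model, and field names below are assumptions made for illustration
# and are not defined elsewhere in this project.
#
# from fastapi import APIRouter, HTTPException
# from pydantic import BaseModel
#
# router = APIRouter()
#
# class ChatRequest(BaseModel):
#     message: str
#     session_id: str
#     assistant_id: str
#
# @router.post("/chat")
# async def chat(req: ChatRequest):
#     try:
#         reply = await chat_service_instance.get_ai_reply(
#             user_message=req.message,
#             session_id=req.session_id,
#             assistant_id=req.assistant_id,
#         )
#         return {"reply": reply}
#     except ValueError as e:
#         raise HTTPException(status_code=404, detail=str(e))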