69 lines
2.9 KiB
Python
69 lines
2.9 KiB
Python
# File: backend/app/api/v1/endpoints/chat.py (Update with DB session dependency)
|
|
# Description: 聊天功能的 API 路由 (使用数据库会话) — chat API routes that use a per-request database session.
|
|
|
|
import logging

from fastapi import APIRouter, HTTPException, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.database import get_db_session
from app.models.pydantic_models import ChatRequest, ChatResponse, SessionCreateRequest
from app.services.chat_service import ChatService # Import class
from app.services.session_service import SessionService # Import class
import app.core.config as Config # Import API Key for ChatService instantiation
|
|
|
|
# --- Module-level wiring ---
# One router plus one instance of each service, created at import time and
# shared by every request this router handles. A dependency-injection setup
# (e.g. FastAPI `Depends` providers) could replace these singletons later.
# Construction order is router -> chat service -> session service.
router, chat_service, session_service = (
    APIRouter(),
    ChatService(default_api_key=Config.GOOGLE_API_KEY),
    SessionService(),
)
|
|
|
|
logger = logging.getLogger(__name__)

# Session id the client sends for a chat that has not been persisted yet;
# seeing it makes the endpoint create a real session from the first message.
TEMP_NEW_CHAT_SESSION_ID = "temp-new-chat"


async def _create_session_for_first_message(db: AsyncSession, assistant_id, first_message):
    """Create a new chat session seeded with the user's first message.

    Args:
        db: Request-scoped async DB session, passed through to SessionService.
        assistant_id: Assistant id taken from the incoming chat request.
        first_message: The user's message, sent as
            ``SessionCreateRequest.first_message``.

    Returns:
        Whatever ``session_service.create_session`` returns; callers read its
        ``id`` and ``title`` attributes.

    Raises:
        HTTPException: 400 when building the request or creating the session
            raises ``ValueError``; any ``HTTPException`` from the service is
            propagated unchanged; 500 with a generic detail for anything else.
    """
    try:
        # Built inside the try so that request-model validation errors
        # (presumably pydantic ValidationError, a ValueError subclass) map to 400.
        create_req = SessionCreateRequest(assistant_id=assistant_id, first_message=first_message)
        created_session = await session_service.create_session(db, create_req)
    except HTTPException:
        # The service already chose an HTTP status; don't mask it as a 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
    except Exception as e:
        logger.exception("创建会话时出错")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败"
        ) from e

    logger.info("新会话已创建: ID=%s, Title='%s'", created_session.id, created_session.title)
    return created_session


@router.post("/", response_model=ChatResponse)
async def handle_chat_message(
    request: ChatRequest,
    db: AsyncSession = Depends(get_db_session),  # Inject DB session
) -> ChatResponse:
    """Reply to one chat message, creating the session first if it is new.

    If ``request.session_id`` equals ``TEMP_NEW_CHAT_SESSION_ID``, a real
    session is created. The reply is then generated for that session, and the
    new id/title are returned so the client can replace its placeholder. For
    existing sessions, ``session_id`` and ``session_title`` in the response
    are ``None``.

    Args:
        request: Incoming chat payload (``message``, ``session_id``,
            ``assistant_id``).
        db: Request-scoped async DB session from ``get_db_session``.

    Returns:
        ChatResponse with the AI reply and, for new sessions, the created
        session's id and title.

    Raises:
        HTTPException: 400 for ``ValueError`` from session creation or the chat
            service; service-raised HTTPExceptions pass through unchanged;
            500 with a generic detail for any other error.
    """
    user_message = request.message
    session_id = request.session_id
    assistant_id = request.assistant_id

    # Message text is logged at DEBUG only, to keep user content out of
    # routine logs.
    logger.debug(
        "接收到消息: User='%s', Session='%s', Assistant='%s'",
        user_message,
        session_id,
        assistant_id,
    )

    response_session_id = None
    response_session_title = None

    if session_id == TEMP_NEW_CHAT_SESSION_ID:
        logger.info("检测到临时新会话,正在创建...")
        created_session = await _create_session_for_first_message(db, assistant_id, user_message)
        # The AI reply must be generated against the real, persisted session.
        session_id = created_session.id
        response_session_id = created_session.id
        response_session_title = created_session.title

    try:
        ai_reply = await chat_service.get_ai_reply(
            db=db,
            user_message=user_message,
            session_id=session_id,
            assistant_id=assistant_id,
        )
    except HTTPException:
        # The service already chose an HTTP status; don't mask it as a 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
    except Exception as e:
        # NOTE(review): rollback on failure is presumably handled by the
        # get_db_session dependency — confirm in app.db.database.
        logger.exception("处理聊天消息时发生错误")
        # Generic detail: str(e) could expose internal errors (DB, upstream
        # AI API) to clients. The full traceback is in the server log.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="处理聊天消息失败"
        ) from e

    logger.debug("发送 AI 回复: '%s'", ai_reply)
    return ChatResponse(
        reply=ai_reply,
        session_id=response_session_id,
        session_title=response_session_title,
    )