From f0863914c294f3464398ab0e5c091f7137424580 Mon Sep 17 00:00:00 2001 From: adrian Date: Wed, 30 Apr 2025 04:39:36 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=95=B0=E6=8D=AE=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/api/v1/api.py | 10 +- backend/app/api/v1/endpoints/assistants.py | 58 +- backend/app/api/v1/endpoints/chat.py | 54 +- backend/app/api/v1/endpoints/messages.py | 33 + backend/app/api/v1/endpoints/sessions.py | 36 +- backend/app/core/config.py | 4 +- backend/app/db/database.py | 43 + backend/app/db/models.py | 52 ++ backend/app/main.py | 39 +- backend/app/models/pydantic_models.py | 25 +- backend/app/services/assistant_service.py | 115 +-- backend/app/services/chat_service.py | 166 ++-- backend/app/services/session_service.py | 114 ++- backend/cherryai.db | Bin 0 -> 45056 bytes frontend/app/chat/page.tsx | 948 +++++++++++++-------- frontend/app/layout.tsx | 5 +- frontend/lib/api.ts | 59 +- frontend/types/assistant.ts | 23 + 18 files changed, 1108 insertions(+), 676 deletions(-) create mode 100644 backend/app/api/v1/endpoints/messages.py create mode 100644 backend/app/db/database.py create mode 100644 backend/app/db/models.py create mode 100644 backend/cherryai.db create mode 100644 frontend/types/assistant.ts diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py index 88b190a..c5f4c57 100644 --- a/backend/app/api/v1/api.py +++ b/backend/app/api/v1/api.py @@ -1,12 +1,12 @@ -# File: backend/app/api/v1/api.py (更新) +# File: backend/app/api/v1/api.py (Update) # Description: 聚合 v1 版本的所有 API 路由 from fastapi import APIRouter -from app.api.v1.endpoints import chat, assistants, sessions # 导入新路由 +from app.api.v1.endpoints import chat, assistants, sessions, messages # Import messages router api_router = APIRouter() api_router.include_router(chat.router, prefix="/chat", tags=["Chat"]) -api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"]) # 添加助手路由 -api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"]) # 添加会话路由 - +api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"]) +api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"]) +api_router.include_router(messages.router, prefix="/messages", tags=["Messages"]) # Add messages router diff --git a/backend/app/api/v1/endpoints/assistants.py b/backend/app/api/v1/endpoints/assistants.py index 899255f..e19a20d 100644 --- a/backend/app/api/v1/endpoints/assistants.py +++ b/backend/app/api/v1/endpoints/assistants.py @@ -1,39 +1,39 @@ -# File: backend/app/api/v1/endpoints/assistants.py (新建) -# Description: 助手的 API 路由 +# File: backend/app/api/v1/endpoints/assistants.py (Update with DB session dependency) +# Description: 助手的 API 路由 (使用数据库会话) from fastapi import APIRouter, HTTPException, Depends, status from typing import List +from sqlalchemy.ext.asyncio import AsyncSession +from app.db.database import get_db_session # Import DB session dependency from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate -from app.services.assistant_service import assistant_service_instance, AssistantService +from app.services.assistant_service import AssistantService # Import the class router = APIRouter() -# --- 依赖注入 AssistantService --- -def get_assistant_service() -> AssistantService: - return assistant_service_instance +# --- Dependency Injection for Service and DB Session --- +# Service instance can be created per request or globally +# For simplicity, let's create it here, but pass db session to methods +assistant_service = AssistantService() @router.post("/", response_model=AssistantRead, status_code=status.HTTP_201_CREATED) async def create_new_assistant( assistant_data: AssistantCreate, - service: AssistantService = Depends(get_assistant_service) + db: AsyncSession = Depends(get_db_session) # Inject DB session ): - """创建新助手""" - return service.create_assistant(assistant_data) + return await assistant_service.create_assistant(db, assistant_data) @router.get("/", response_model=List[AssistantRead]) async def read_all_assistants( - service: AssistantService = Depends(get_assistant_service) + db: AsyncSession = Depends(get_db_session) ): - """获取所有助手列表""" - return service.get_assistants() + return await assistant_service.get_assistants(db) @router.get("/{assistant_id}", response_model=AssistantRead) async def read_assistant_by_id( assistant_id: str, - service: AssistantService = Depends(get_assistant_service) + db: AsyncSession = Depends(get_db_session) ): - """根据 ID 获取特定助手""" - assistant = service.get_assistant(assistant_id) + assistant = await assistant_service.get_assistant(db, assistant_id) if not assistant: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手") return assistant @@ -42,10 +42,9 @@ async def read_assistant_by_id( async def update_existing_assistant( assistant_id: str, assistant_data: AssistantUpdate, - service: AssistantService = Depends(get_assistant_service) + db: AsyncSession = Depends(get_db_session) ): - """更新指定 ID 的助手""" - updated_assistant = service.update_assistant(assistant_id, assistant_data) + updated_assistant = await assistant_service.update_assistant(db, assistant_id, assistant_data) if not updated_assistant: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手") return updated_assistant @@ -53,14 +52,17 @@ async def update_existing_assistant( @router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_existing_assistant( assistant_id: str, - service: AssistantService = Depends(get_assistant_service) + db: AsyncSession = Depends(get_db_session) ): - """删除指定 ID 的助手""" - deleted = service.delete_assistant(assistant_id) - if not deleted: - # 根据服务层逻辑判断是找不到还是不允许删除 - assistant = service.get_assistant(assistant_id) - if assistant and assistant_id == 'asst-default': - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="不允许删除默认助手") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手") - # 成功删除,不返回内容 + # Handle potential error from service if trying to delete default + try: + deleted = await assistant_service.delete_assistant(db, assistant_id) + if not deleted: + # Check if it exists to differentiate 404 from 403 (or handle in service) + assistant = await assistant_service.get_assistant(db, assistant_id) + if assistant and assistant_id == 'asst-default': + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="不允许删除默认助手") + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手") + except Exception as e: # Catch other potential DB errors + print(f"删除助手时出错: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除助手失败") diff --git a/backend/app/api/v1/endpoints/chat.py b/backend/app/api/v1/endpoints/chat.py index 66fc018..1cdfd0e 100644 --- a/backend/app/api/v1/endpoints/chat.py +++ b/backend/app/api/v1/endpoints/chat.py @@ -1,27 +1,26 @@ -# File: backend/app/api/v1/endpoints/chat.py (更新) -# Description: 聊天功能的 API 路由 (使用更新后的 ChatService) +# File: backend/app/api/v1/endpoints/chat.py (Update with DB session dependency) +# Description: 聊天功能的 API 路由 (使用数据库会话) from fastapi import APIRouter, HTTPException, Depends, status +from sqlalchemy.ext.asyncio import AsyncSession +from app.db.database import get_db_session from app.models.pydantic_models import ChatRequest, ChatResponse, SessionCreateRequest -from app.services.chat_service import chat_service_instance, ChatService -from app.services.session_service import session_service_instance, SessionService # 导入 SessionService +from app.services.chat_service import ChatService # Import class +from app.services.session_service import SessionService # Import class +import app.core.config as Config # Import API Key for ChatService instantiation router = APIRouter() -# --- 依赖注入 --- -def get_chat_service() -> ChatService: - return chat_service_instance - -def get_session_service() -> SessionService: - return session_service_instance +# --- Dependency Injection --- +# Instantiate services here or use a more sophisticated dependency injection system +chat_service = ChatService(default_api_key=Config.GOOGLE_API_KEY) +session_service = SessionService() @router.post("/", response_model=ChatResponse) async def handle_chat_message( request: ChatRequest, - chat_service: ChatService = Depends(get_chat_service), - session_service: SessionService = Depends(get_session_service) # 注入 SessionService + db: AsyncSession = Depends(get_db_session) # Inject DB session ): - """处理用户发送的聊天消息 (包含 assistantId 和 sessionId)""" user_message = request.message session_id = request.session_id assistant_id = request.assistant_id @@ -31,38 +30,39 @@ async def handle_chat_message( response_session_id = None response_session_title = None - # --- 处理临时新会话 --- if session_id == 'temp-new-chat': print("检测到临时新会话,正在创建...") try: - # 调用 SessionService 创建会话 create_req = SessionCreateRequest(assistant_id=assistant_id, first_message=user_message) - created_session = await session_service.create_session(create_req) - session_id = created_session.id # 使用新创建的会话 ID - response_session_id = created_session.id # 准备在响应中返回新 ID - response_session_title = created_session.title # 准备在响应中返回新标题 + # Pass db session to the service method + created_session = await session_service.create_session(db, create_req) + session_id = created_session.id + response_session_id = created_session.id + response_session_title = created_session.title print(f"新会话已创建: ID={session_id}, Title='{created_session.title}'") - except ValueError as e: # 助手不存在等错误 + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) - except Exception as e: # LLM 调用或其他错误 + except Exception as e: print(f"创建会话时出错: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败") - # --- 调用聊天服务获取回复 --- try: + # Pass db session to the service method ai_reply = await chat_service.get_ai_reply( + db=db, user_message=user_message, - session_id=session_id, # 使用真实的或新创建的 session_id + session_id=session_id, assistant_id=assistant_id ) print(f"发送 AI 回复: '{ai_reply}'") return ChatResponse( reply=ai_reply, - session_id=response_session_id, # 返回新 ID (如果创建了) - session_title=response_session_title # 返回新标题 (如果创建了) + session_id=response_session_id, + session_title=response_session_title ) - except ValueError as e: # 助手不存在等错误 + except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) - except Exception as e: # LLM 调用或其他错误 + except Exception as e: print(f"处理聊天消息时发生错误: {e}") + # The get_db_session dependency will handle rollback raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) diff --git a/backend/app/api/v1/endpoints/messages.py b/backend/app/api/v1/endpoints/messages.py new file mode 100644 index 0000000..09d6588 --- /dev/null +++ b/backend/app/api/v1/endpoints/messages.py @@ -0,0 +1,33 @@ +# File: backend/app/api/v1/endpoints/messages.py (New) +# Description: API endpoint for fetching messages + +from fastapi import APIRouter, Depends, HTTPException, status, Query +from typing import List +from sqlalchemy.ext.asyncio import AsyncSession +from app.db.database import get_db_session +from app.models.pydantic_models import MessageRead +from app.db.models import MessageModel # Import DB model +from sqlalchemy.future import select + +router = APIRouter() + +@router.get("/session/{session_id}", response_model=List[MessageRead]) +async def read_messages_for_session( + session_id: str, + db: AsyncSession = Depends(get_db_session), + skip: int = Query(0, ge=0), # Offset for pagination + limit: int = Query(100, ge=1, le=500) # Limit number of messages +): + """获取指定会话的消息列表 (按时间顺序)""" + # TODO: Add check if session exists + stmt = ( + select(MessageModel) + .filter(MessageModel.session_id == session_id) + .order_by(MessageModel.order.asc()) # Fetch in chronological order + .offset(skip) + .limit(limit) + ) + result = await db.execute(stmt) + messages = result.scalars().all() + # Validate using Pydantic model before returning + return [MessageRead.model_validate(msg) for msg in messages] \ No newline at end of file diff --git a/backend/app/api/v1/endpoints/sessions.py b/backend/app/api/v1/endpoints/sessions.py index a4ddd53..b249581 100644 --- a/backend/app/api/v1/endpoints/sessions.py +++ b/backend/app/api/v1/endpoints/sessions.py @@ -1,47 +1,43 @@ -# File: backend/app/api/v1/endpoints/sessions.py (新建) -# Description: 会话管理的 API 路由 +# File: backend/app/api/v1/endpoints/sessions.py (Update with DB session dependency) +# Description: 会话管理的 API 路由 (使用数据库会话) from fastapi import APIRouter, HTTPException, Depends, status from typing import List +from sqlalchemy.ext.asyncio import AsyncSession +from app.db.database import get_db_session from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse -from app.services.session_service import session_service_instance, SessionService +from app.services.session_service import SessionService # Import the class router = APIRouter() - -def get_session_service() -> SessionService: - return session_service_instance +session_service = SessionService() # Create instance @router.post("/", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED) async def create_new_session( session_data: SessionCreateRequest, - service: SessionService = Depends(get_session_service) + db: AsyncSession = Depends(get_db_session) # Inject DB session ): - """创建新会话并自动生成标题""" try: - return await service.create_session(session_data) + return await session_service.create_session(db, session_data) except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) except Exception as e: - # 处理可能的 LLM 调用错误 print(f"创建会话时出错: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败") @router.get("/assistant/{assistant_id}", response_model=List[SessionRead]) async def read_sessions_for_assistant( assistant_id: str, - service: SessionService = Depends(get_session_service) + db: AsyncSession = Depends(get_db_session) ): - """获取指定助手的所有会话列表""" - # TODO: 添加检查助手是否存在 - return service.get_sessions_by_assistant(assistant_id) + # Consider adding check if assistant exists first + return await session_service.get_sessions_by_assistant(db, assistant_id) @router.get("/{session_id}", response_model=SessionRead) async def read_session_by_id( session_id: str, - service: SessionService = Depends(get_session_service) + db: AsyncSession = Depends(get_db_session) ): - """获取单个会话信息""" - session = service.get_session(session_id) + session = await session_service.get_session(db, session_id) if not session: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话") return session @@ -49,10 +45,8 @@ async def read_session_by_id( @router.delete("/{session_id}", status_code=status.HTTP_204_NO_CONTENT) async def delete_existing_session( session_id: str, - service: SessionService = Depends(get_session_service) + db: AsyncSession = Depends(get_db_session) ): - """删除指定 ID 的会话""" - deleted = service.delete_session(session_id) + deleted = await session_service.delete_session(db, session_id) if not deleted: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话") - diff --git a/backend/app/core/config.py b/backend/app/core/config.py index d8327ff..4df53df 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -12,4 +12,6 @@ load_dotenv() OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # 如果使用 Google -# 可以在这里添加其他配置项 \ No newline at end of file +# Define the database URL (SQLite in this case) +# DATABASE_URL = "sqlite+aiosqlite:///./cherryai.db" # Use async driver +DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./cherryai.db") \ No newline at end of file diff --git a/backend/app/db/database.py b/backend/app/db/database.py new file mode 100644 index 0000000..5144691 --- /dev/null +++ b/backend/app/db/database.py @@ -0,0 +1,43 @@ +# File: backend/app/db/database.py (New - Database setup) +# Description: SQLAlchemy 数据库引擎和会话设置 + +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession +from sqlalchemy.orm import sessionmaker, declarative_base +from app.core.config import DATABASE_URL + +# 创建异步数据库引擎 +# connect_args={"check_same_thread": False} is needed only for SQLite. +# It's not needed for other databases. +engine = create_async_engine(DATABASE_URL, echo=True, connect_args={"check_same_thread": False}) + +# 创建异步会话工厂 +# expire_on_commit=False prevents attributes from expiring after commit. +AsyncSessionFactory = sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autocommit=False, + autoflush=False, +) + +# 创建数据模型的基础类 +Base = declarative_base() + +# --- Dependency to get DB session --- +async def get_db_session() -> AsyncSession: + """FastAPI dependency to get an async database session.""" + async with AsyncSessionFactory() as session: + try: + yield session + await session.commit() # Commit transaction if successful + except Exception: + await session.rollback() # Rollback on error + raise + finally: + await session.close() + +# --- Function to create tables (call this on startup) --- +async def create_db_and_tables(): + async with engine.begin() as conn: + # await conn.run_sync(Base.metadata.drop_all) # Use drop_all carefully in dev + await conn.run_sync(Base.metadata.create_all) \ No newline at end of file diff --git a/backend/app/db/models.py b/backend/app/db/models.py new file mode 100644 index 0000000..33399ab --- /dev/null +++ b/backend/app/db/models.py @@ -0,0 +1,52 @@ +# File: backend/app/db/models.py (New - Database models) +# Description: SQLAlchemy ORM 模型定义 + +from sqlalchemy import Column, String, Float, ForeignKey, Text, DateTime, Integer +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func # For default timestamps +from app.db.database import Base +import uuid +from datetime import datetime, timezone + +def generate_uuid(): + return str(uuid.uuid4()) + +class AssistantModel(Base): + __tablename__ = "assistants" + + id = Column(String, primary_key=True, default=generate_uuid) + name = Column(String(50), nullable=False, index=True) + description = Column(String(200), nullable=True) + avatar = Column(String(5), nullable=True) + system_prompt = Column(Text, nullable=False) + model = Column(String, nullable=False) + temperature = Column(Float, nullable=False, default=0.7) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + sessions = relationship("SessionModel", back_populates="assistant", cascade="all, delete-orphan") + +class SessionModel(Base): + __tablename__ = "sessions" + + id = Column(String, primary_key=True, default=generate_uuid) + title = Column(String(100), nullable=False, default="New Chat") + assistant_id = Column(String, ForeignKey("assistants.id"), nullable=False, index=True) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now(), index=True) # Index for sorting + + assistant = relationship("AssistantModel", back_populates="sessions") + messages = relationship("MessageModel", back_populates="session", cascade="all, delete-orphan", order_by="MessageModel.created_at") # Order messages by time + +class MessageModel(Base): + __tablename__ = "messages" + + id = Column(String, primary_key=True, default=generate_uuid) + session_id = Column(String, ForeignKey("sessions.id"), nullable=False, index=True) + sender = Column(String(10), nullable=False) # 'user' or 'ai' or 'system' + text = Column(Text, nullable=False) + order = Column(Integer, nullable=False) # Explicit order within session + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + session = relationship("SessionModel", back_populates="messages") + diff --git a/backend/app/main.py b/backend/app/main.py index 45e4ef2..45f25b4 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,20 +1,31 @@ -# File: backend/app/main.py (确认 load_dotenv 调用位置) -# Description: FastAPI 应用入口 +# File: backend/app/main.py (Update - Add startup event) +# Description: FastAPI 应用入口 (添加数据库初始化) from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware from app.api.v1.api import api_router as api_router_v1 -# 确保在创建 FastAPI 实例之前加载环境变量 -from app.core.config import OPENAI_API_KEY # 导入会触发 load_dotenv +import app.core.config # Ensure config is loaded +from app.db.database import create_db_and_tables # Import table creation function +from contextlib import asynccontextmanager -# 创建 FastAPI 应用实例 -app = FastAPI(title="CherryAI Backend", version="0.1.0") +# --- Lifespan context manager for startup/shutdown events --- +@asynccontextmanager +async def lifespan(app: FastAPI): + # Startup actions + print("应用程序启动中...") + await create_db_and_tables() # Create database tables on startup + print("数据库表已检查/创建。") + # You can add the default assistant creation here if needed, + # but doing it in the service/model definition might be simpler for defaults. + yield + # Shutdown actions + print("应用程序关闭中...") -# --- 配置 CORS --- -origins = [ - "http://localhost:3000", - "http://127.0.0.1:3000", -] +# Create FastAPI app with lifespan context manager +app = FastAPI(title="CherryAI Backend", version="0.1.0", lifespan=lifespan) + +# --- CORS Middleware --- +origins = [ "http://localhost:3000", "http://127.0.0.1:3000" ] app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -23,10 +34,10 @@ app.add_middleware( allow_headers=["*"], ) -# --- 挂载 API 路由 --- +# --- API Routers --- app.include_router(api_router_v1, prefix="/api/v1") -# --- 根路径 --- +# --- Root Endpoint --- @app.get("/", tags=["Root"]) async def read_root(): - return {"message": "欢迎来到 CherryAI 后端!"} + return {"message": "欢迎来到 CherryAI 后端!"} \ No newline at end of file diff --git a/backend/app/models/pydantic_models.py b/backend/app/models/pydantic_models.py index d001128..c5278c2 100644 --- a/backend/app/models/pydantic_models.py +++ b/backend/app/models/pydantic_models.py @@ -1,9 +1,10 @@ -# File: backend/app/models/pydantic_models.py (更新) +# File: backend/app/models/pydantic_models.py (Update Read models, add Message models) # Description: Pydantic 模型定义 API 数据结构 from pydantic import BaseModel, Field from typing import Optional, List -import uuid # 用于生成唯一 ID +import uuid +from datetime import datetime # Use datetime directly # --- Assistant Models --- @@ -33,7 +34,8 @@ class AssistantUpdate(BaseModel): class AssistantRead(AssistantBase): """读取助手信息时返回的模型 (包含 ID)""" id: str = Field(..., description="助手唯一 ID") - + created_at: datetime # Add timestamps + updated_at: Optional[datetime] = None class Config: from_attributes = True # Pydantic v2: orm_mode = True @@ -70,7 +72,22 @@ class SessionRead(BaseModel): id: str title: str assistant_id: str - created_at: str + created_at: datetime # Use datetime + updated_at: Optional[datetime] = None + + class Config: + from_attributes = True + +# --- Message Models (New) --- +class MessageBase(BaseModel): + sender: str # 'user' or 'ai' + text: str + +class MessageRead(MessageBase): + id: str + session_id: str + order: int + created_at: datetime class Config: from_attributes = True \ No newline at end of file diff --git a/backend/app/services/assistant_service.py b/backend/app/services/assistant_service.py index ae4dfd4..50cefb9 100644 --- a/backend/app/services/assistant_service.py +++ b/backend/app/services/assistant_service.py @@ -1,73 +1,76 @@ -# File: backend/app/services/assistant_service.py (新建) -# Description: 管理助手数据的服务 (内存实现) +# File: backend/app/services/assistant_service.py (Update with DB) +# Description: 管理助手数据的服务 (使用 SQLAlchemy) -from typing import Dict, List, Optional -from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate -import uuid - -# 使用字典作为内存数据库存储助手 -# key: assistant_id (str), value: AssistantRead object -assistants_db: Dict[str, AssistantRead] = {} - -# 添加默认助手 (确保 ID 与前端 Mock 一致) -default_assistant = AssistantRead( - id='asst-default', - name='默认助手', - description='通用聊天助手', - avatar='🤖', - system_prompt='你是一个乐于助人的 AI 助手。', - model='gpt-3.5-turbo', - temperature=0.7 -) -assistants_db[default_assistant.id] = default_assistant +from typing import List, Optional +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import update as sqlalchemy_update, delete as sqlalchemy_delete +from app.db.models import AssistantModel +from app.models.pydantic_models import AssistantCreate, AssistantUpdate, AssistantRead class AssistantService: - """助手数据的 CRUD 服务""" + """助手数据的 CRUD 服务 (数据库版)""" - def get_assistants(self) -> List[AssistantRead]: + async def get_assistants(self, db: AsyncSession) -> List[AssistantRead]: """获取所有助手""" - return list(assistants_db.values()) + result = await db.execute(select(AssistantModel).order_by(AssistantModel.name)) + assistants = result.scalars().all() + return [AssistantRead.model_validate(a) for a in assistants] # Use model_validate in Pydantic v2 - def get_assistant(self, assistant_id: str) -> Optional[AssistantRead]: + async def get_assistant(self, db: AsyncSession, assistant_id: str) -> Optional[AssistantRead]: """根据 ID 获取单个助手""" - return assistants_db.get(assistant_id) + result = await db.execute(select(AssistantModel).filter(AssistantModel.id == assistant_id)) + assistant = result.scalars().first() + return AssistantRead.model_validate(assistant) if assistant else None - def create_assistant(self, assistant_data: AssistantCreate) -> AssistantRead: + async def create_assistant(self, db: AsyncSession, assistant_data: AssistantCreate) -> AssistantRead: """创建新助手""" - new_id = f"asst-{uuid.uuid4()}" # 生成唯一 ID - new_assistant = AssistantRead(id=new_id, **assistant_data.model_dump()) - assistants_db[new_id] = new_assistant - print(f"助手已创建: {new_id} - {new_assistant.name}") - return new_assistant + # 使用 Pydantic 模型创建 DB 模型实例 + db_assistant = AssistantModel(**assistant_data.model_dump()) + # ID will be generated by default in the model + db.add(db_assistant) + await db.flush() # Flush to get the generated ID and defaults + await db.refresh(db_assistant) # Refresh to load all attributes + print(f"助手已创建 (DB): {db_assistant.id} - {db_assistant.name}") + return AssistantRead.model_validate(db_assistant) - def update_assistant(self, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]: + async def update_assistant(self, db: AsyncSession, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]: """更新现有助手""" - existing_assistant = assistants_db.get(assistant_id) - if not existing_assistant: - return None + update_values = assistant_data.model_dump(exclude_unset=True) + if not update_values: + # If nothing to update, just fetch and return the existing one + return await self.get_assistant(db, assistant_id) - # 使用 Pydantic 的 model_copy 和 update 来更新字段 - update_data = assistant_data.model_dump(exclude_unset=True) # 只获取设置了值的字段 - if update_data: - updated_assistant = existing_assistant.model_copy(update=update_data) - assistants_db[assistant_id] = updated_assistant - print(f"助手已更新: {assistant_id}") - return updated_assistant - return existing_assistant # 如果没有更新任何字段,返回原始助手 + # Execute update statement + stmt = ( + sqlalchemy_update(AssistantModel) + .where(AssistantModel.id == assistant_id) + .values(**update_values) + .returning(AssistantModel) # Return the updated row + ) + result = await db.execute(stmt) + updated_assistant = result.scalars().first() - def delete_assistant(self, assistant_id: str) -> bool: + if updated_assistant: + await db.flush() + await db.refresh(updated_assistant) + print(f"助手已更新 (DB): {assistant_id}") + return AssistantRead.model_validate(updated_assistant) + return None # Assistant not found + + async def delete_assistant(self, db: AsyncSession, assistant_id: str) -> bool: """删除助手""" - if assistant_id in assistants_db: - # 添加逻辑:不允许删除默认助手 - if assistant_id == 'asst-default': - print("尝试删除默认助手 - 操作被阻止") - return False # 或者抛出特定异常 - del assistants_db[assistant_id] - print(f"助手已删除: {assistant_id}") - # TODO: 在实际应用中,还需要删除关联的会话和消息 + # Prevent deleting default assistant + if assistant_id == 'asst-default': # Assuming 'asst-default' is a known ID + print("尝试删除默认助手 - 操作被阻止") + return False + + stmt = sqlalchemy_delete(AssistantModel).where(AssistantModel.id == assistant_id) + result = await db.execute(stmt) + if result.rowcount > 0: + await db.flush() + print(f"助手已删除 (DB): {assistant_id}") + # Deletion of sessions/messages handled by cascade="all, delete-orphan" return True return False -# 创建服务实例 -assistant_service_instance = AssistantService() - diff --git a/backend/app/services/chat_service.py b/backend/app/services/chat_service.py index ce2770a..f4151e4 100644 --- a/backend/app/services/chat_service.py +++ b/backend/app/services/chat_service.py @@ -1,127 +1,137 @@ -# File: backend/app/services/chat_service.py (更新) -# Description: 封装 LangChain 聊天逻辑 (支持助手配置和会话历史) +# File: backend/app/services/chat_service.py (Update with DB for history) +# Description: 封装 LangChain 聊天逻辑 (使用数据库存储和检索消息) from langchain_openai import ChatOpenAI +from langchain_google_genai import ChatGoogleGenerativeAI from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.output_parsers import StrOutputParser from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage -from typing import Dict, List, Optional -from app.services.assistant_service import assistant_service_instance # 获取助手配置 -from app.models.pydantic_models import AssistantRead # 引入助手模型 -import app.core.config as Config - -# --- 更新内存管理 --- -# 使用字典存储不同会话的内存 -# key: session_id (str), value: List[BaseMessage] -chat_history_db: Dict[str, List[BaseMessage]] = {} +from typing import Dict, List, Optional, Tuple +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from app.db.models import MessageModel, AssistantModel # Import DB models +from app.services.assistant_service import AssistantService # Use class directly +from app.models.pydantic_models import AssistantRead class ChatService: - """处理 AI 聊天交互的服务 (支持助手配置)""" + """处理 AI 聊天交互的服务 (使用数据库历史)""" def __init__(self, default_api_key: str): - """初始化时可传入默认 API Key""" self.default_api_key = default_api_key - # 不再在 init 中创建固定的 LLM 和 chain + self.assistant_service = AssistantService() # Instantiate assistant service def _get_llm(self, assistant: AssistantRead) -> ChatOpenAI: - """根据助手配置动态创建 LLM 实例""" - # TODO: 支持不同模型提供商 (Gemini, Anthropic etc.) + # ... (remains the same) ... if assistant.model.startswith("gpt"): - return ChatOpenAI( - model=assistant.model, - api_key=self.default_api_key, # 或从助手配置中读取特定 key - temperature=assistant.temperature - ) + return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature) elif assistant.model.startswith("gemini"): - from langchain_google_genai import ChatGoogleGenerativeAI return ChatGoogleGenerativeAI( model=assistant.model, api_key=self.default_api_key, # 或从助手配置中读取特定 key temperature=assistant.temperature ) else: - # 默认或抛出错误 print(f"警告: 模型 {assistant.model} 未明确支持,尝试使用 ChatOpenAI") - return ChatOpenAI( - model=assistant.model, - api_key=self.default_api_key, - temperature=assistant.temperature - ) + return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature) - async def get_ai_reply(self, user_message: str, session_id: str, assistant_id: str) -> str: - """ - 获取 AI 对用户消息的回复 (使用指定助手和会话历史) - Args: - user_message (str): 用户发送的消息 - session_id (str): 会话 ID - assistant_id (str): 使用的助手 ID - Returns: - str: AI 的回复文本 - Raises: - ValueError: 如果找不到指定的助手 - Exception: 如果调用 AI 服务时发生错误 - """ + + async def _get_chat_history(self, db: AsyncSession, session_id: str, limit: int = 10) -> Tuple[List[BaseMessage], int]: + """从数据库加载指定会话的历史消息 (按 order 排序)""" + stmt = ( + select(MessageModel) + .filter(MessageModel.session_id == session_id) + .order_by(MessageModel.order.desc()) # Get latest first + .limit(limit) + ) + result = await db.execute(stmt) + db_messages = result.scalars().all() + + # Convert to LangChain messages (in correct order: oldest to newest) + history: List[BaseMessage] = [] + max_order = 0 + for msg in reversed(db_messages): # Reverse to get chronological order + if msg.sender == 'user': + history.append(HumanMessage(content=msg.text)) + elif msg.sender == 'ai': + history.append(AIMessage(content=msg.text)) + # Add handling for 'system' if needed + max_order = max(max_order, msg.order) # Keep track of the latest order number + + return history, max_order + + async def _save_message(self, db: AsyncSession, session_id: str, sender: str, text: str, order: int): + """将消息保存到数据库""" + db_message = MessageModel( + session_id=session_id, + sender=sender, + text=text, + order=order + ) + db.add(db_message) + await db.flush() # Ensure it's added before potential commit + print(f"消息已保存 (DB): Session={session_id}, Order={order}, Sender={sender}") + + + async def get_ai_reply(self, db: AsyncSession, user_message: str, session_id: str, assistant_id: str) -> str: + """获取 AI 回复,并保存用户消息和 AI 回复到数据库""" # 1. 获取助手配置 - assistant = assistant_service_instance.get_assistant(assistant_id) + assistant = await self.assistant_service.get_assistant(db, assistant_id) if not assistant: raise ValueError(f"找不到助手 ID: {assistant_id}") - # 2. 获取或初始化当前会话的历史记录 - current_chat_history = chat_history_db.get(session_id, []) + # 2. 获取历史记录和下一个序号 + current_chat_history, last_order = await self._get_chat_history(db, session_id) + user_message_order = last_order + 1 + ai_message_order = last_order + 2 - # 3. 构建 Prompt (包含动态系统提示) + # 3. 构建 Prompt prompt = ChatPromptTemplate.from_messages([ - SystemMessage(content=assistant.system_prompt), # 使用助手的系统提示 + SystemMessage(content=assistant.system_prompt), MessagesPlaceholder(variable_name="chat_history"), - HumanMessage(content="{input}"), + HumanMessage(content=user_message), ]) - # 4. 获取 LLM 实例 + # 4. 获取 LLM llm = self._get_llm(assistant) - - # 5. 定义输出解析器 output_parser = StrOutputParser() - - # 6. 构建 LCEL 链 chain = prompt | llm | output_parser try: - # 7. 调用链获取回复 + # --- Save user message BEFORE calling LLM --- + await self._save_message(db, session_id, 'user', user_message, user_message_order) + + # 5. 调用链获取回复 ai_response = await chain.ainvoke({ "input": user_message, - "chat_history": current_chat_history, + "chat_history": current_chat_history, # Pass history fetched from DB }) - # 8. 更新会话历史记录 - current_chat_history.append(HumanMessage(content=user_message)) - current_chat_history.append(AIMessage(content=ai_response)) - # 限制历史记录长度 (例如最近 10 条消息) - max_history_length = 10 - if len(current_chat_history) > max_history_length: - chat_history_db[session_id] = current_chat_history[-max_history_length:] - else: - chat_history_db[session_id] = current_chat_history + # --- Save AI response AFTER getting it --- + await self._save_message(db, session_id, 'ai', ai_response, ai_message_order) + # Note: We don't need to manage history in memory anymore (chat_history_db removed) return ai_response except Exception as e: + # Consider rolling back the user message save if LLM call fails, + # although often it's better to keep the user message. + # await db.rollback() # Handled by get_db_session dependency on error print(f"调用 LangChain 时出错 (助手: {assistant_id}, 会话: {session_id}): {e}") raise Exception(f"AI 服务暂时不可用: {e}") - # (可选) 添加一个简单的文本生成方法用于生成标题 - async def generate_text(self, prompt_text: str, model_name: str = "gpt-3.5-turbo", temperature: float = 0.5) -> str: - """使用指定模型生成文本 (用于标题等)""" - try: - # 使用一个临时的、可能更便宜的模型 - temp_llm = ChatOpenAI(model=model_name, api_key=self.default_api_key, temperature=temperature) - response = await temp_llm.ainvoke(prompt_text) - return response.content - except Exception as e: - print(f"生成文本时出错: {e}") - return "无法生成标题" # 返回默认值或抛出异常 + async def generate_text(self, prompt_text: str, model_name: str = "gemini-2.0-flash", temperature: float = 0.5) -> str: + # ... (remains the same) ... + try: + temp_llm = ChatGoogleGenerativeAI( + model=model_name, + api_key=self.default_api_key, # 或从助手配置中读取特定 key + temperature=temperature + ) + response = await temp_llm.ainvoke(prompt_text) + return response.content + except Exception as e: + print(f"生成文本时出错: {e}") + return "无法生成标题" -# --- 创建 ChatService 实例 --- -if not Config.GOOGLE_API_KEY: - raise ValueError("请在 .env 文件中设置 OPENAI_API_KEY") -chat_service_instance = ChatService(default_api_key=Config.GOOGLE_API_KEY) +# ChatService instance is now created where needed or injected, no global instance here. diff --git a/backend/app/services/session_service.py b/backend/app/services/session_service.py index 579cb76..77b2d37 100644 --- a/backend/app/services/session_service.py +++ b/backend/app/services/session_service.py @@ -1,82 +1,74 @@ -# File: backend/app/services/session_service.py (新建) -# Description: 管理会话数据的服务 (内存实现) +# File: backend/app/services/session_service.py (Update with DB) +# Description: 管理会话数据的服务 (使用 SQLAlchemy) -from typing import Dict, List, Optional -from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse, AssistantRead -from app.services.assistant_service import assistant_service_instance # 需要获取助手信息 +from typing import List, Optional +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import delete as sqlalchemy_delete +from app.db.models import SessionModel, AssistantModel # Import DB models +from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse from datetime import datetime, timezone -import uuid -# 导入 ChatService 以调用 LLM 生成标题 (避免循环导入,考虑重构) -# from app.services.chat_service import chat_service_instance - -# 内存数据库存储会话 -# key: session_id (str), value: SessionRead object -sessions_db: Dict[str, SessionRead] = {} - +# Import ChatService for title generation (consider refactoring later) +from app.services.chat_service import ChatService +import app.core.config as Config +chat_service_instance = ChatService(Config.GOOGLE_API_KEY) class SessionService: - """会话数据的 CRUD 及标题生成服务""" + """会话数据的 CRUD 及标题生成服务 (数据库版)""" - async def create_session(self, session_data: SessionCreateRequest) -> SessionCreateResponse: + async def create_session(self, db: AsyncSession, session_data: SessionCreateRequest) -> SessionCreateResponse: """创建新会话并生成标题""" - assistant = assistant_service_instance.get_assistant(session_data.assistant_id) + # 检查助手是否存在 + result = await db.execute(select(AssistantModel).filter(AssistantModel.id == session_data.assistant_id)) + assistant = result.scalars().first() if not assistant: raise ValueError("指定的助手不存在") - new_id = f"session-{uuid.uuid4()}" - created_time = datetime.now(timezone.utc) + # --- 调用 LLM 生成标题 --- + try: + title_prompt = f"根据以下用户第一条消息,为此对话生成一个简洁的标题(不超过10个字):\n\n{session_data.first_message}" + generated_title = await chat_service_instance.generate_text(title_prompt) + except Exception as e: + print(f"生成会话标题时出错: {e}") + generated_title = f"关于 \"{session_data.first_message[:15]}...\"" # Fallback + # --- 生成结束 --- - # --- TODO: 调用 LLM 生成标题 --- - # title_prompt = f"根据以下用户第一条消息,为此对话生成一个简洁的标题(不超过10个字):\n\n{session_data.first_message}" - # generated_title = await chat_service_instance.generate_text(title_prompt) # 需要一个简单的文本生成方法 - - # 模拟标题生成 - generated_title = f"关于 \"{session_data.first_message[:15]}...\"" - print(f"为新会话 {new_id} 生成标题: {generated_title}") - # --- 模拟结束 --- - - new_session = SessionRead( - id=new_id, + db_session = SessionModel( title=generated_title, - assistant_id=session_data.assistant_id, - created_at=created_time.isoformat() # 存储 ISO 格式字符串 + assistant_id=session_data.assistant_id + # ID and created_at have defaults ) - sessions_db[new_id] = new_session - print(f"会话已创建: {new_id}") + db.add(db_session) + await db.flush() + await db.refresh(db_session) + print(f"会话已创建 (DB): {db_session.id}") return SessionCreateResponse( - id=new_session.id, - title=new_session.title, - assistant_id=new_session.assistant_id, - created_at=new_session.created_at + id=db_session.id, + title=db_session.title, + assistant_id=db_session.assistant_id, + created_at=db_session.created_at.isoformat() # Use datetime from DB model ) - def get_sessions_by_assistant(self, assistant_id: str) -> List[SessionRead]: + async def get_sessions_by_assistant(self, db: AsyncSession, assistant_id: str) -> List[SessionRead]: """获取指定助手的所有会话""" - return [s for s in sessions_db.values() if s.assistant_id == assistant_id] + stmt = select(SessionModel).filter(SessionModel.assistant_id == assistant_id).order_by(SessionModel.updated_at.desc()) # Order by update time + result = await db.execute(stmt) + sessions = result.scalars().all() + return [SessionRead.model_validate(s) for s in sessions] - def get_session(self, session_id: str) -> Optional[SessionRead]: + async def get_session(self, db: AsyncSession, session_id: str) -> Optional[SessionRead]: """获取单个会话""" - return sessions_db.get(session_id) + result = await db.execute(select(SessionModel).filter(SessionModel.id == session_id)) + session = result.scalars().first() + return SessionRead.model_validate(session) if session else None - def delete_session(self, session_id: str) -> bool: + async def delete_session(self, db: AsyncSession, session_id: str) -> bool: """删除会话""" - if session_id in sessions_db: - del sessions_db[session_id] - print(f"会话已删除: {session_id}") - # TODO: 删除关联的消息 + stmt = sqlalchemy_delete(SessionModel).where(SessionModel.id == session_id) + result = await db.execute(stmt) + if result.rowcount > 0: + await db.flush() + print(f"会话已删除 (DB): {session_id}") + # Deletion of messages handled by cascade return True - return False - - def delete_sessions_by_assistant(self, assistant_id: str) -> int: - """删除指定助手的所有会话""" - ids_to_delete = [s.id for s in sessions_db.values() if s.assistant_id == assistant_id] - count = 0 - for session_id in ids_to_delete: - if self.delete_session(session_id): - count += 1 - print(f"删除了助手 {assistant_id} 的 {count} 个会话") - return count - - -# 创建服务实例 -session_service_instance = SessionService() + return False \ No newline at end of file diff --git a/backend/cherryai.db b/backend/cherryai.db new file mode 100644 index 0000000000000000000000000000000000000000..6e9cbc4fc5eb3f533dbf0fd77acb7985cf000683 GIT binary patch literal 45056 zcmeI5-BTOa6~HAhVB`-aGtJP>)YEIyOo4h+?XEuQq%Czw>ZxO=1lo8~7p>6BQ3`?x zxlN{%jsPRTM&Q`SU*HcMY!f>cICg9-5Ps-C(KpW|x-01^eau52dhV_kt%OOGnKYSr zjs~ICz2}~L?(dv)_k&&TYk$}qjL0O=8}5-JL^M5RGMi1mA;e@dS>e3}-ukNsHa6*B z;L~i_-e}Wmdj6jWwpqV0Su8&=Z4X)hwf(8Bscnz8?%P5w?`^pYh4?}ONB{{S0VIF~ zkN^_+&Iuf{)YN%)mjr!$z<-)EV^aMO${awYIdkl0Exg zY<-81s_SV-ypFwARZx|ZY!vNb` zUv4e#uPZD|5!v4fN6s$-g!VnJw2~KFUux;t+fH_Ow6(SFYwu)*ueGMb@$|2*Aij}V&4+V5<0aQ+y7!7GO2FhB+TBuZgS*YkpI(dDo`p=cs zb^K2AVcmfRjh!K>M=pHbRO-A!l`6^HRwyD3taamHo4KZL=T7sHUe#fR`odO4k<$w0 zU*a~VTK7kTqr_=`xxXtM?2Ev(wI(CjY-;LBf0H6oVM1`Ih5hgLN93N)zHo0(-(%*3 zo?gG)t!pasKD6qS!&2mMSSBy+-QQ9?kNm%4!_p{LRM)+{bED~?K(Lc9Y*)fgW;5H- zY_tB|RI&YlY5RfgUwr2~9y9<6AOR$R1dsp{Kmter2_OL^fCRoz1P1EOmJWX4$ME~P zZQDAy8c`Gkw@VNMe$Fol0%zwvF3#=419V~oyz*8ew7r#wScG#0R=!n$CxE?EEl zeF`$P6$u~#B!C2v01`j~NB{{S0VIF~kiZj6prNwET+Zc+il0?>)RuEyK|B9{WU~Ip z`sfMfhC@RFNB{{S0VIF~kN^@u0!RP}AOR%sWDt0!lCRZnq0nxX&~8b1s`5Fr@z#Y3 zYvpr!F->3_jNWWLV}c)iApsW&|Eutx)U z!DsiulM=oznX`KWPR?T!ZJdMmcXfFLhwSl4@UR5Enkwgy{l(8g*qqIFk^TMO47dK{ z-T&XNgRwUxfCP{L5Z|>-Ldj4&*7_dnJw<92PU4qZg z*=;hcMStnvK!awIn~7$x4dmruEsoJM z+1dHLXqyZxUxxD1ng~5U!B*^8lhV5jkrzm}98!${<$>NX36fxlgr(3Sxgn$)M#!^7 zS(sEln#~Qw=|n6Ct86UBl$j`&9s-S;tWllq+{{CuDR>!N~bUj9h!gOFf= z5qaaSgXCA*2KhPR4{E6LxS(>Ep$^8ZwonE06$e3ESmnm*ES6_UW-ggtJ(f8&4?eK; zyxZ9gtL8$>H=OoeM1;5C*hP2&KH<&tyNJhj&|+a8g=L8IvPL)OH8ff*ec@mz(onzN zG0pY6h(SRkqmx}dNk@{+~|F{kW7l=0m9i|eX9vUo6g4OU1f;GKzUX9lq8pGVsJROaJ7A=DF zA*ZGt3*TEGWM&gkn@+_us-XQp3=|MTm~@RpBcKRHIDE6V8ybY&7c&LpDB-2pnYsLd4VrMWdstVwjGghCywSOvIuJ880Q6JuF%nGx|{^~TDoXrTIdsRKKs4K-m{BCY(PbcsLv9BdEtT9 zIQwuqJ3FqK#K=Cl2hp5aKFKD{TW9IuJVY7s8s-fa8Mllb<;i)~iAPU!ZDQGAq%718 z`yl`~aVc{=#)fM#*1&R55S+P@IGu|Fb1gyF@0Y;r#|FMO9A$hgGkjecAEAl!;G^7Z z5`y^il~GVmmzG#A#7ESUG2`sdd_8@aW$eeZ%FLJo+jFdWUS{(B%&ALEzdCEt#oPIj zCdQS~LJ{jh%}q=D`EG$Ys2I5`P=8ZBBw`rr<|FjCU+ zFM?jpth_M!m};2(Gan{PtbtB>wbRKP`Pl&oPc3GWFu&ep@UL8e30Z3gyDBg{x(u+~0SIOD`SOa!2h7tYIT4Hi#)+ zPG%DcO%TE`ck?V`QH%h@;9vp{b?C#Jx#%U0&S%j9nmPdqt$k%CT#o5ggVn{)2n!;t zpf*rdRnyV&%pSAYbs5)_se0R zp&gsj*f}lqjX7Sm&JMvb)Ub&iA`d@p7VVCJ!{(G+oFqzi&hEBLoQE0Wvh(%;FWKC( zQ!2+0DVSZjuMEyBlgHDm({%WLI&~5t67N5Pp}xbB_h8H?HkW@zBSH|M7*qi2L zXv89SCaS=31i;i3EG6LIs)64$?5q{cl#>hT<*V9J3*;cgH>4(vSV?@M9VdZUHt`|o zr%MSYKN?dm9tStURGnRz0tcm6&eHMIEZOr-VegVVlgM2;ThbJQFLUcuv7j<^2aY96 z1ZiTNo<346m|k9{M`zay!qmkCEeycSCAvDS4P775z@}nX!J>N(_dr!;K4knZx7}s4 zNt}k{3ZDEU23m&?`V3-IL_A_d`q2VBoUE!?N?AI&a|0UCG;pT&df6>wQL zI;_P8h_PCBGziNCPM5?#t!Nh{yOZ}zoGiNC98eTFj~I|Rr`zKc zMUST|AiK-4n4x=9?(Xhg-5jkN^@u0!RP}AOR$R z1dsp{Kmt!L0k*bYIe&VsW`gBh39Dj)sBc`F}Zr U>iK^;g6jEyIfCl>e>sBx1MV!&m;e9( literal 0 HcmV?d00001 diff --git a/frontend/app/chat/page.tsx b/frontend/app/chat/page.tsx index 426d983..e9f68b5 100644 --- a/frontend/app/chat/page.tsx +++ b/frontend/app/chat/page.tsx @@ -1,38 +1,82 @@ // File: frontend/app/chat/page.tsx (更新以使用 API) // Description: 对接后端 API 实现助手和会话的加载与管理 -'use client'; +"use client"; -import React, { useState, useRef, useEffect, useCallback } from 'react'; -import { SendHorizontal, Loader2, PanelRightOpen, PanelRightClose, UserPlus, Settings2, Trash2, Edit, RefreshCw } from 'lucide-react'; // 添加刷新图标 +import React, { useState, useRef, useEffect, useCallback } from "react"; +import { + SendHorizontal, + Loader2, + PanelRightOpen, + PanelRightClose, + UserPlus, + Settings2, + Trash2, + Edit, + RefreshCw, +} from "lucide-react"; // 添加刷新图标 import { useForm } from "react-hook-form"; import { zodResolver } from "@hookform/resolvers/zod"; import * as z from "zod"; // Shadcn UI Components import { Button } from "@/components/ui/button"; -import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger, DialogFooter, DialogClose } from "@/components/ui/dialog"; -import { Form, FormControl, FormDescription, FormField, FormItem, FormLabel, FormMessage } from "@/components/ui/form"; +import { + Dialog, + DialogContent, + DialogDescription, + DialogHeader, + DialogTitle, + DialogTrigger, + DialogFooter, + DialogClose, +} from "@/components/ui/dialog"; +import { + Form, + FormControl, + FormDescription, + FormField, + FormItem, + FormLabel, + FormMessage, +} from "@/components/ui/form"; import { Input } from "@/components/ui/input"; import { Textarea } from "@/components/ui/textarea"; -import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; import { Slider } from "@/components/ui/slider"; import { Toaster, toast } from "sonner"; import { Skeleton } from "@/components/ui/skeleton"; // 导入骨架屏 // API 函数和类型 import { - sendChatMessage, getAssistants, createAssistant, updateAssistant, deleteAssistant, - getSessionsByAssistant, deleteSession, - Assistant, Session, AssistantCreateData, AssistantUpdateData, ChatApiResponse -} from '@/lib/api'; // 确保路径正确 + sendChatMessage, + getAssistants, + createAssistant, + updateAssistant, + deleteAssistant, + getSessionsByAssistant, + deleteSession, + getMessagesBySession, + Session, + ChatApiResponse, + Message as ApiMessage, +} from "@/lib/api"; // 确保路径正确 +import { + Assistant, + AssistantCreateData, + AssistantUpdateData, +} from "@/types/assistant"; -// --- 数据接口定义 --- -interface Message { - id: string; - text: string; - sender: 'user' | 'ai'; - isError?: boolean; +// --- Frontend specific Message type (includes optional isError) --- +interface Message extends ApiMessage { + // Extend the type from API + isError?: boolean; // Optional flag for frontend error styling } interface ChatSession { @@ -45,10 +89,16 @@ interface ChatSession { // --- Zod Schema for Assistant Form Validation --- const assistantFormSchema = z.object({ - name: z.string().min(1, { message: "助手名称不能为空" }).max(50, { message: "名称过长" }), + name: z + .string() + .min(1, { message: "助手名称不能为空" }) + .max(50, { message: "名称过长" }), description: z.string().max(200, { message: "描述过长" }).optional(), avatar: z.string().max(5, { message: "头像/Emoji 过长" }).optional(), // 简单限制长度 - system_prompt: z.string().min(1, { message: "系统提示不能为空" }).max(4000, { message: "系统提示过长" }), + system_prompt: z + .string() + .min(1, { message: "系统提示不能为空" }) + .max(4000, { message: "系统提示过长" }), model: z.string({ required_error: "请选择一个模型" }), temperature: z.number().min(0).max(1), }); @@ -57,19 +107,22 @@ type AssistantFormData = z.infer; // 可选的模型列表 const availableModels = [ - { value: "gpt-3.5-turbo", label: "GPT-3.5 Turbo" }, - { value: "gpt-4", label: "GPT-4" }, - { value: "gpt-4-turbo", label: "GPT-4 Turbo" }, - { value: "gemini-2.0-flash", label: "Gemini 2.0 Flash" }, - { value: "deepseek-coder", label: "DeepSeek Coder" }, // 示例 - // 添加更多模型... + { value: "gpt-3.5-turbo", label: "GPT-3.5 Turbo" }, + { value: "gpt-4", label: "GPT-4" }, + { value: "gpt-4-turbo", label: "GPT-4 Turbo" }, + { value: "gemini-2.0-flash", label: "Gemini 2.0 Flash" }, + { value: "deepseek-coder", label: "DeepSeek Coder" }, // 示例 + // 添加更多模型... ]; // --- Helper Function --- -const findLastSession = (sessions: ChatSession[], assistantId: string): ChatSession | undefined => { - return sessions - .filter(s => s.assistantId === assistantId && !s.isTemporary) - .sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime())[0]; +const findLastSession = ( + sessions: ChatSession[], + assistantId: string +): ChatSession | undefined => { + return sessions + .filter((s) => s.assistantId === assistantId && !s.isTemporary) + .sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime())[0]; }; // --- Assistant Form Component --- @@ -182,7 +235,7 @@ function AssistantForm({ assistant, onSave, onClose }: AssistantFormProps) { - {availableModels.map(model => ( + {availableModels.map((model) => ( {model.label} @@ -199,7 +252,9 @@ function AssistantForm({ assistant, onSave, onClose }: AssistantFormProps) { name="temperature" render={({ field }) => ( - 温度 (Temperature): {field.value.toFixed(1)} + + 温度 (Temperature): {field.value.toFixed(1)} + {/* Shadcn Slider expects an array for value */} - - 值越低越稳定,越高越有创造性。 - + 值越低越稳定,越高越有创造性。 )} /> - - - - + + ); } - // --- Main Chat Page Component --- export default function ChatPage() { // --- State Variables --- - const [inputMessage, setInputMessage] = useState(''); + const [inputMessage, setInputMessage] = useState(""); + // Messages state now holds Message type from API const [messages, setMessages] = useState([]); const [isLoading, setIsLoading] = useState(false); // AI 回复加载状态 @@ -244,15 +301,19 @@ export default function ChatPage() { const [isSessionPanelOpen, setIsSessionPanelOpen] = useState(true); const [isAssistantDialogOpen, setIsAssistantDialogOpen] = useState(false); // 控制助手表单 Dialog 显隐 - const [editingAssistant, setEditingAssistant] = useState(null); // 当前正在编辑的助手 + const [editingAssistant, setEditingAssistant] = useState( + null + ); // 当前正在编辑的助手 - // Data Loading States + // Data Loading States const [assistantsLoading, setAssistantsLoading] = useState(true); const [sessionsLoading, setSessionsLoading] = useState(false); - + const [messagesLoading, setMessagesLoading] = useState(false); // Data State const [assistants, setAssistants] = useState([]); - const [currentAssistantId, setCurrentAssistantId] = useState(null); // 初始为 null + const [currentAssistantId, setCurrentAssistantId] = useState( + null + ); // 初始为 null const [allSessions, setAllSessions] = useState([]); const [currentSessionId, setCurrentSessionId] = useState(null); // 初始为 null @@ -261,19 +322,20 @@ export default function ChatPage() { // --- Effects --- // Initial data loading (Assistants) + // Initial Assistant loading useEffect(() => { const loadAssistants = async () => { setAssistantsLoading(true); try { const fetchedAssistants = await getAssistants(); setAssistants(fetchedAssistants); - // 设置默认选中的助手 (例如第一个或 ID 为 'asst-default' 的) - const defaultAssistant = fetchedAssistants.find(a => a.id === 'asst-default') || fetchedAssistants[0]; + const defaultAssistant = + fetchedAssistants.find((a) => a.id === "asst-default") || + fetchedAssistants[0]; if (defaultAssistant) { setCurrentAssistantId(defaultAssistant.id); } else { - console.warn("No default or initial assistant found."); - // 可能需要提示用户创建助手 + console.warn("No default or initial assistant found."); } } catch (apiError: any) { toast.error(`加载助手列表失败: ${apiError.message}`); @@ -283,84 +345,109 @@ export default function ChatPage() { } }; loadAssistants(); - }, []); // 空依赖数组,只在挂载时运行一次 + }, []); + // Load sessions when assistant changes (remains same, but calls handleSelectSession internally) + useEffect(() => { + if (!currentAssistantId) return; + const loadSessions = async () => { + setSessionsLoading(true); + setCurrentSessionId(null); + setMessages([]); + try { + const fetchedSessions = await getSessionsByAssistant(currentAssistantId); + // Filter out sessions that might belong to a deleted assistant still in cache + const validAssistants = new Set(assistants.map(a => a.id)); + setAllSessions(prev => [ + ...prev.filter(s => s.assistant_id !== currentAssistantId && validAssistants.has(s.assistant_id)), + ...fetchedSessions + ]); + const lastSession = fetchedSessions.sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime())[0]; + if (lastSession) { + setCurrentSessionId(lastSession.id); // Trigger message loading effect + } else { + setCurrentSessionId('temp-new-chat'); + const currentAssistant = assistants.find(a => a.id === currentAssistantId); + setMessages([{ id: `init-temp-${currentAssistantId}`, session_id: 'temp-new-chat', sender: 'ai', text: `开始与 ${currentAssistant?.name || '助手'} 的新对话吧!`, order: 0, created_at: new Date().toISOString() }]); + } + } catch (apiError: any) { toast.error(`加载会话列表失败: ${apiError.message}`); } + finally { setSessionsLoading(false); } + }; + loadSessions(); + }, [currentAssistantId]); // 空依赖数组,只在挂载时运行一次 // Load sessions when assistant changes useEffect(() => { - if (!currentAssistantId) return; // 如果没有选中助手,则不加载 - - const loadSessions = async () => { - setSessionsLoading(true); - // 清空当前会话和消息列表 - setCurrentSessionId(null); - setMessages([]); - try { - const fetchedSessions = await getSessionsByAssistant(currentAssistantId); - // 更新全局会话列表 (只保留其他助手的会话,加上当前助手的) - setAllSessions(prev => [ - ...prev.filter(s => s.assistant_id !== currentAssistantId), - ...fetchedSessions + if (!currentSessionId || currentSessionId === "temp-new-chat") { + // If it's temp-new-chat, messages are already set or should be empty initially + if (currentSessionId === "temp-new-chat" && messages.length === 0) { + // Ensure initial message is set + const currentAssistant = assistants.find( + (a) => a.id === currentAssistantId + ); + setMessages([ + { + id: `init-temp-${currentAssistantId}`, + session_id: "temp-new-chat", + sender: "ai", + text: `开始与 ${currentAssistant?.name || "助手"} 的新对话吧!`, + order: 0, + created_at: new Date().toISOString(), + }, ]); + } + return; + } - // 查找最新的会话并设为当前 - const lastSession = fetchedSessions - .sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime())[0]; - - if (lastSession) { - setCurrentSessionId(lastSession.id); - // TODO: 加载 lastSession.id 的历史消息 - console.log(`加载助手 ${currentAssistantId} 的最后一个会话: ${lastSession.id}`); - const currentAssistant = assistants.find(a => a.id === currentAssistantId); - setMessages([ { id: `init-${lastSession.id}-1`, text: `继续与 ${currentAssistant?.name || '助手'} 的对话: ${lastSession.title}`, sender: 'ai' } ]); - } else { - // 没有历史会话,进入临时新对话状态 - setCurrentSessionId('temp-new-chat'); - console.log(`助手 ${currentAssistantId} 没有历史会话,创建临时新对话`); - const currentAssistant = assistants.find(a => a.id === currentAssistantId); - setMessages([ { id: `init-temp-${currentAssistantId}`, text: `开始与 ${currentAssistant?.name || '助手'} 的新对话吧!`, sender: 'ai' } ]); - } + const loadMessages = async () => { + setMessagesLoading(true); + setError(null); // Clear previous errors + console.log(`加载会话 ${currentSessionId} 的消息...`); + try { + const fetchedMessages = await getMessagesBySession(currentSessionId); + setMessages(fetchedMessages); + console.log(`成功加载 ${fetchedMessages.length} 条消息`); } catch (apiError: any) { - toast.error(`加载会话列表失败: ${apiError.message}`); + toast.error(`加载消息失败: ${apiError.message}`); + setError(`无法加载消息: ${apiError.message}`); + setMessages([]); // Clear messages on error } finally { - setSessionsLoading(false); + setMessagesLoading(false); } }; + loadMessages(); + }, [currentSessionId]); // 依赖助手 ID 和助手列表 (以防助手信息更新) - loadSessions(); - }, [currentAssistantId, assistants]); // 依赖助手 ID 和助手列表 (以防助手信息更新) - // Auto scroll useEffect(() => { - messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); + messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }); }, [messages]); - + // Filter sessions for the current assistant (UI display) const currentAssistantSessions = React.useMemo(() => { - // 直接从 allSessions 过滤,因为加载时已经更新了 - return allSessions - .filter(s => s.assistant_id === currentAssistantId) - .sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime()); // 按时间倒序 + // 直接从 allSessions 过滤,因为加载时已经更新了 + return allSessions + .filter((s) => s.assistant_id === currentAssistantId) + .sort( + (a, b) => + new Date(b.created_at).getTime() - new Date(a.created_at).getTime() + ); // 按时间倒序 }, [allSessions, currentAssistantId]); - // --- Assistant CRUD Handlers (Updated with API calls) --- + // --- Assistant CRUD Handlers (Updated with API calls) --- const handleSaveAssistant = async (data: AssistantFormData, id?: string) => { try { let savedAssistant: Assistant; if (id) { // 编辑 savedAssistant = await updateAssistant(id, data); - setAssistants(prev => prev.map(a => (a.id === id ? savedAssistant : a))); + setAssistants((prev) => + prev.map((a) => (a.id === id ? savedAssistant : a)) + ); toast.success(`助手 "${savedAssistant.name}" 已更新`); - // 如果更新的是当前助手,可能需要重新加载会话或消息 - if (id === currentAssistantId) { - // 简单处理:可以强制刷新会话列表(或提示用户) - setCurrentAssistantId(null); // 触发 useEffect 重新加载 - setTimeout(() => setCurrentAssistantId(id), 0); - } } else { // 创建 savedAssistant = await createAssistant(data); - setAssistants(prev => [...prev, savedAssistant]); + setAssistants((prev) => [...prev, savedAssistant]); toast.success(`助手 "${savedAssistant.name}" 已创建`); // 创建后自动选中 handleSelectAssistant(savedAssistant.id); @@ -371,8 +458,8 @@ export default function ChatPage() { } }; - const handleDeleteAssistant = async (idToDelete: string) => { - if (idToDelete === 'asst-default' || assistants.length <= 1) { + const handleDeleteAssistant = async (idToDelete: string) => { + if (idToDelete === "asst-default" || assistants.length <= 1) { toast.error("不能删除默认助手或最后一个助手"); return; } @@ -381,12 +468,16 @@ export default function ChatPage() { return; } - const assistantToDelete = assistants.find(a => a.id === idToDelete); - if (window.confirm(`确定要删除助手 "${assistantToDelete?.name}" 吗?相关会话也将被删除。`)) { + const assistantToDelete = assistants.find((a) => a.id === idToDelete); + if ( + window.confirm( + `确定要删除助手 "${assistantToDelete?.name}" 吗?相关会话也将被删除。` + ) + ) { try { await deleteAssistant(idToDelete); // 后端应负责删除关联会话,前端只需更新助手列表 - setAssistants(prev => prev.filter(a => a.id !== idToDelete)); + setAssistants((prev) => prev.filter((a) => a.id !== idToDelete)); // (可选) 如果需要立即清除前端的会话缓存 // setAllSessions(prev => prev.filter(s => s.assistant_id !== idToDelete)); toast.success(`助手 "${assistantToDelete?.name}" 已删除`); @@ -397,86 +488,124 @@ export default function ChatPage() { }; const handleEditAssistant = (assistant: Assistant) => { - setEditingAssistant(assistant); - setIsAssistantDialogOpen(true); + setEditingAssistant(assistant); + setIsAssistantDialogOpen(true); }; const handleOpenCreateAssistantDialog = () => { - setEditingAssistant(null); - setIsAssistantDialogOpen(true); + setEditingAssistant(null); + setIsAssistantDialogOpen(true); }; - // --- Send Message Handler (Updated with API response handling) --- + // --- Send Message Handler (Updated - handles new session ID from response) --- const handleSendMessage = async (e?: React.FormEvent) => { e?.preventDefault(); const trimmedMessage = inputMessage.trim(); - if (!trimmedMessage || isLoading || !currentSessionId || !currentAssistantId) return; // 增加检查 + if ( + !trimmedMessage || + isLoading || + !currentSessionId || + !currentAssistantId + ) + return; setError(null); - setIsLoading(true); + setIsLoading(true); // Start loading (for AI reply) - const userMessage: Message = { - id: Date.now().toString(), + const tempUserMessageId = `temp-user-${Date.now()}`; // Temporary ID for optimistic update + const userMessageOptimistic: Message = { + id: tempUserMessageId, + session_id: + currentSessionId === "temp-new-chat" ? "temp" : currentSessionId, // Use temp session id if needed text: trimmedMessage, - sender: 'user', + sender: "user", + order: (messages[messages.length - 1]?.order || 0) + 1, // Estimate order + created_at: new Date().toISOString(), }; - // 立即显示用户消息 - setMessages(prev => [...prev, userMessage]); - setInputMessage(''); // 清空输入框 + + // Optimistic UI update: Add user message immediately + setMessages((prev) => [...prev, userMessageOptimistic]); + setInputMessage(""); + + let targetSessionId = currentSessionId; // Will be updated if new session is created try { - // 调用后端 API const response: ChatApiResponse = await sendChatMessage( - trimmedMessage, - currentSessionId, // 发送当前 session ID ('temp-new-chat' 或真实 ID) - currentAssistantId + trimmedMessage, + currentSessionId, // Send 'temp-new-chat' or actual ID + currentAssistantId ); - // 处理 AI 回复 + // Process successful response const aiMessage: Message = { - id: Date.now().toString() + '_ai', + id: `ai-${Date.now()}`, // Use temporary or actual ID from backend if provided + session_id: response.session_id || targetSessionId, // Use new session ID if available text: response.reply, - sender: 'ai', + sender: "ai", + order: userMessageOptimistic.order + 1, // Estimate order + created_at: new Date().toISOString(), }; - setMessages((prevMessages) => [...prevMessages, aiMessage]); - // 如果后端创建了新会话并返回了信息 - if (response.session_id && response.session_title && currentSessionId === 'temp-new-chat') { + // Update messages: Replace temp user message with potential real one (if backend returned it) + // and add AI message. For simplicity, we just add the AI message. + // A more robust solution would involve matching IDs. + setMessages((prev) => [ + ...prev.filter((m) => m.id !== tempUserMessageId), + userMessageOptimistic, + aiMessage, + ]); // Keep optimistic user msg for now + + // If a new session was created by the backend + if ( + response.session_id && + response.session_title && + currentSessionId === "temp-new-chat" + ) { const newSession: Session = { id: response.session_id, title: response.session_title, assistant_id: currentAssistantId, - created_at: new Date().toISOString(), // 使用客户端时间或后端返回的时间 + created_at: new Date().toISOString(), // Or use time from backend if available }; - // 更新全局会话列表和当前会话 ID - setAllSessions(prev => [...prev, newSession]); - setCurrentSessionId(newSession.id); - console.log(`前端已更新新会话信息: ID=${newSession.id}, Title=${newSession.title}`); + setAllSessions((prev) => [...prev, newSession]); + setCurrentSessionId(newSession.id); // Switch to the new session ID + // Update the session_id of the messages just added + setMessages((prev) => + prev.map((m) => + m.session_id === "temp" ? { ...m, session_id: newSession.id } : m + ) + ); + console.log( + `前端已更新新会话信息: ID=${newSession.id}, Title=${newSession.title}` + ); } - } catch (apiError: any) { - console.error("发送消息失败:", apiError); - const errorMessageText = apiError.message || '发生未知错误'; - setError(errorMessageText); + // Handle error: Remove optimistic user message and show error + setMessages((prev) => prev.filter((m) => m.id !== tempUserMessageId)); + const errorMessageText = apiError.message || "发生未知错误"; + toast.error(`发送消息失败: ${errorMessageText}`); + setError(`发送消息失败: ${errorMessageText}`); + // Optionally add an error message to the chat const errorMessage: Message = { - id: Date.now().toString() + '_err', + /* ... */ id: `err-${Date.now()}`, + session_id: targetSessionId, text: `错误: ${errorMessageText}`, - sender: 'ai', - isError: true, + sender: "ai", + order: userMessageOptimistic.order + 1, + created_at: new Date().toISOString(), }; setMessages((prevMessages) => [...prevMessages, errorMessage]); } finally { - setIsLoading(false); + setIsLoading(false); // Stop AI reply loading } }; - // --- Other Handlers (基本不变, 但需要检查 currentAssistantId/currentSessionId 是否存在) --- - const handleInputChange = (e: React.ChangeEvent) => { + const handleInputChange = (e: React.ChangeEvent) => { setInputMessage(e.target.value); }; const handleKeyDown = (e: React.KeyboardEvent) => { - if (e.key === 'Enter' && !e.shiftKey && !isLoading) { + if (e.key === "Enter" && !e.shiftKey && !isLoading) { e.preventDefault(); handleSendMessage(); } @@ -487,203 +616,286 @@ export default function ChatPage() { }; const handleSelectAssistant = (assistantId: string) => { - if (assistantId !== currentAssistantId) { - setCurrentAssistantId(assistantId); // 触发 useEffect 加载会话 - } - } + console.log("选择助手id",assistantId) + if (assistantId !== currentAssistantId) { + setCurrentAssistantId(assistantId); // 触发 useEffect 加载会话 + console.log("当前助手id",currentAssistantId) + } + }; + // Updated handleSelectSession to just set the ID, useEffect handles loading const handleSelectSession = (sessionId: string) => { - if (sessionId !== currentSessionId) { - setCurrentSessionId(sessionId); - // TODO: 调用 API 加载该会话的历史消息 - console.log(`切换到会话: ${sessionId}`); - const session = allSessions.find(s => s.id === sessionId); - const assistant = assistants.find(a => a.id === session?.assistant_id); - setMessages([ - { id: `init-${sessionId}-1`, text: `继续与 ${assistant?.name || '助手'} 的对话: ${session?.title || ''}`, sender: 'ai' }, - // ... 加载真实历史消息 - ]); - } - } + if (sessionId !== currentSessionId) { + setCurrentSessionId(sessionId); // Trigger useEffect to load messages + } + }; const handleNewTopic = () => { - if (currentSessionId !== 'temp-new-chat' && currentAssistantId) { // 确保有助手被选中 - setCurrentSessionId('temp-new-chat'); - const currentAssistant = assistants.find(a => a.id === currentAssistantId); - setMessages([ - { id: `init-temp-${currentAssistantId}`, text: `开始与 ${currentAssistant?.name || '助手'} 的新对话吧!`, sender: 'ai' }, - ]); - console.log("手动创建临时新对话"); - } - } + if (currentSessionId !== "temp-new-chat" && currentAssistantId) { + // 确保有助手被选中 + setCurrentSessionId("temp-new-chat"); + console.log("手动创建临时新对话"); + } + }; // --- JSX Rendering --- return ( // 最外层 Flex 容器 -
{/* 使用 gap 添加间距 */} +
+ {" "} + {/* 使用 gap 添加间距 */} {/* 左侧助手面板 */} - {/* 中间主聊天区域 */}
{/* 聊天窗口标题 - 显示当前助手和切换会话按钮 */}
- {currentAssistantId ? ( -
- {assistants.find(a => a.id === currentAssistantId)?.avatar || '👤'} -

- {assistants.find(a => a.id === currentAssistantId)?.name || '加载中...'} - - ({currentSessionId === 'temp-new-chat' ? '新话题' : allSessions.find(s => s.id === currentSessionId)?.title || (sessionsLoading ? '加载中...' : '选择话题')}) - -

-
- ) : ( - // 助手加载中显示骨架屏 - )} - + {currentAssistantId ? ( +
+ + {assistants.find((a) => a.id === currentAssistantId)?.avatar || + "👤"} + +

+ {assistants.find((a) => a.id === currentAssistantId)?.name || + "加载中..."} + + ( + {currentSessionId === "temp-new-chat" + ? "新话题" + : allSessions.find((s) => s.id === currentSessionId) + ?.title || (sessionsLoading ? "加载中..." : "选择话题")} + ) + +

+
+ ) : ( + // 助手加载中显示骨架屏 + )} +
{/* 消息显示区域 */}
{/* 可以添加一个全局错误提示 */} - {error &&

{error}

} - {messages.length === 0 && !isLoading && !sessionsLoading && currentSessionId !== 'temp-new-chat' && ( -

选择一个话题开始聊天,或新建一个话题。

+ {error && ( +

+ {error} +

)} - {messages.map((message) => ( -
+ {messagesLoading ? ( + // Message loading skeleton +
+ + + +
+ ) : messages.length === 0 && currentSessionId !== "temp-new-chat" ? ( +

+ {currentAssistantId + ? "选择或新建一个话题开始聊天。" + : "请先选择一个助手。"} +

+ ) : ( + // Render actual messages + messages.map((message) => (
-

{message.text}

+
+

{message.text}

+
-
- ))} + )) + )} {isLoading && (
- AI 正在思考... + + AI 正在思考... +
)}
@@ -691,76 +903,112 @@ export default function ChatPage() { {/* 消息输入区域 */}
-
+
- {/* 右侧会话管理面板 */} -
); } diff --git a/frontend/app/layout.tsx b/frontend/app/layout.tsx index ee55954..0c1c472 100644 --- a/frontend/app/layout.tsx +++ b/frontend/app/layout.tsx @@ -45,7 +45,7 @@ export default function RootLayout({ {/* 侧边栏导航 */} -