添加数据管理
This commit is contained in:
parent
7df10e82be
commit
f0863914c2
@ -1,12 +1,12 @@
|
||||
# File: backend/app/api/v1/api.py (更新)
|
||||
# File: backend/app/api/v1/api.py (Update)
|
||||
# Description: 聚合 v1 版本的所有 API 路由
|
||||
|
||||
from fastapi import APIRouter
|
||||
from app.api.v1.endpoints import chat, assistants, sessions # 导入新路由
|
||||
from app.api.v1.endpoints import chat, assistants, sessions, messages # Import messages router
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
api_router.include_router(chat.router, prefix="/chat", tags=["Chat"])
|
||||
api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"]) # 添加助手路由
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"]) # 添加会话路由
|
||||
|
||||
api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"])
|
||||
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
|
||||
api_router.include_router(messages.router, prefix="/messages", tags=["Messages"]) # Add messages router
|
||||
|
||||
@ -1,39 +1,39 @@
|
||||
# File: backend/app/api/v1/endpoints/assistants.py (新建)
|
||||
# Description: 助手的 API 路由
|
||||
# File: backend/app/api/v1/endpoints/assistants.py (Update with DB session dependency)
|
||||
# Description: 助手的 API 路由 (使用数据库会话)
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from typing import List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.db.database import get_db_session # Import DB session dependency
|
||||
from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate
|
||||
from app.services.assistant_service import assistant_service_instance, AssistantService
|
||||
from app.services.assistant_service import AssistantService # Import the class
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# --- 依赖注入 AssistantService ---
|
||||
def get_assistant_service() -> AssistantService:
|
||||
return assistant_service_instance
|
||||
# --- Dependency Injection for Service and DB Session ---
|
||||
# Service instance can be created per request or globally
|
||||
# For simplicity, let's create it here, but pass db session to methods
|
||||
assistant_service = AssistantService()
|
||||
|
||||
@router.post("/", response_model=AssistantRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_new_assistant(
|
||||
assistant_data: AssistantCreate,
|
||||
service: AssistantService = Depends(get_assistant_service)
|
||||
db: AsyncSession = Depends(get_db_session) # Inject DB session
|
||||
):
|
||||
"""创建新助手"""
|
||||
return service.create_assistant(assistant_data)
|
||||
return await assistant_service.create_assistant(db, assistant_data)
|
||||
|
||||
@router.get("/", response_model=List[AssistantRead])
|
||||
async def read_all_assistants(
|
||||
service: AssistantService = Depends(get_assistant_service)
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""获取所有助手列表"""
|
||||
return service.get_assistants()
|
||||
return await assistant_service.get_assistants(db)
|
||||
|
||||
@router.get("/{assistant_id}", response_model=AssistantRead)
|
||||
async def read_assistant_by_id(
|
||||
assistant_id: str,
|
||||
service: AssistantService = Depends(get_assistant_service)
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""根据 ID 获取特定助手"""
|
||||
assistant = service.get_assistant(assistant_id)
|
||||
assistant = await assistant_service.get_assistant(db, assistant_id)
|
||||
if not assistant:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
|
||||
return assistant
|
||||
@ -42,10 +42,9 @@ async def read_assistant_by_id(
|
||||
async def update_existing_assistant(
|
||||
assistant_id: str,
|
||||
assistant_data: AssistantUpdate,
|
||||
service: AssistantService = Depends(get_assistant_service)
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""更新指定 ID 的助手"""
|
||||
updated_assistant = service.update_assistant(assistant_id, assistant_data)
|
||||
updated_assistant = await assistant_service.update_assistant(db, assistant_id, assistant_data)
|
||||
if not updated_assistant:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
|
||||
return updated_assistant
|
||||
@ -53,14 +52,17 @@ async def update_existing_assistant(
|
||||
@router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_existing_assistant(
|
||||
assistant_id: str,
|
||||
service: AssistantService = Depends(get_assistant_service)
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""删除指定 ID 的助手"""
|
||||
deleted = service.delete_assistant(assistant_id)
|
||||
# Handle potential error from service if trying to delete default
|
||||
try:
|
||||
deleted = await assistant_service.delete_assistant(db, assistant_id)
|
||||
if not deleted:
|
||||
# 根据服务层逻辑判断是找不到还是不允许删除
|
||||
assistant = service.get_assistant(assistant_id)
|
||||
# Check if it exists to differentiate 404 from 403 (or handle in service)
|
||||
assistant = await assistant_service.get_assistant(db, assistant_id)
|
||||
if assistant and assistant_id == 'asst-default':
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="不允许删除默认助手")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
|
||||
# 成功删除,不返回内容
|
||||
except Exception as e: # Catch other potential DB errors
|
||||
print(f"删除助手时出错: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除助手失败")
|
||||
|
||||
@ -1,27 +1,26 @@
|
||||
# File: backend/app/api/v1/endpoints/chat.py (更新)
|
||||
# Description: 聊天功能的 API 路由 (使用更新后的 ChatService)
|
||||
# File: backend/app/api/v1/endpoints/chat.py (Update with DB session dependency)
|
||||
# Description: 聊天功能的 API 路由 (使用数据库会话)
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.db.database import get_db_session
|
||||
from app.models.pydantic_models import ChatRequest, ChatResponse, SessionCreateRequest
|
||||
from app.services.chat_service import chat_service_instance, ChatService
|
||||
from app.services.session_service import session_service_instance, SessionService # 导入 SessionService
|
||||
from app.services.chat_service import ChatService # Import class
|
||||
from app.services.session_service import SessionService # Import class
|
||||
import app.core.config as Config # Import API Key for ChatService instantiation
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# --- 依赖注入 ---
|
||||
def get_chat_service() -> ChatService:
|
||||
return chat_service_instance
|
||||
|
||||
def get_session_service() -> SessionService:
|
||||
return session_service_instance
|
||||
# --- Dependency Injection ---
|
||||
# Instantiate services here or use a more sophisticated dependency injection system
|
||||
chat_service = ChatService(default_api_key=Config.GOOGLE_API_KEY)
|
||||
session_service = SessionService()
|
||||
|
||||
@router.post("/", response_model=ChatResponse)
|
||||
async def handle_chat_message(
|
||||
request: ChatRequest,
|
||||
chat_service: ChatService = Depends(get_chat_service),
|
||||
session_service: SessionService = Depends(get_session_service) # 注入 SessionService
|
||||
db: AsyncSession = Depends(get_db_session) # Inject DB session
|
||||
):
|
||||
"""处理用户发送的聊天消息 (包含 assistantId 和 sessionId)"""
|
||||
user_message = request.message
|
||||
session_id = request.session_id
|
||||
assistant_id = request.assistant_id
|
||||
@ -31,38 +30,39 @@ async def handle_chat_message(
|
||||
response_session_id = None
|
||||
response_session_title = None
|
||||
|
||||
# --- 处理临时新会话 ---
|
||||
if session_id == 'temp-new-chat':
|
||||
print("检测到临时新会话,正在创建...")
|
||||
try:
|
||||
# 调用 SessionService 创建会话
|
||||
create_req = SessionCreateRequest(assistant_id=assistant_id, first_message=user_message)
|
||||
created_session = await session_service.create_session(create_req)
|
||||
session_id = created_session.id # 使用新创建的会话 ID
|
||||
response_session_id = created_session.id # 准备在响应中返回新 ID
|
||||
response_session_title = created_session.title # 准备在响应中返回新标题
|
||||
# Pass db session to the service method
|
||||
created_session = await session_service.create_session(db, create_req)
|
||||
session_id = created_session.id
|
||||
response_session_id = created_session.id
|
||||
response_session_title = created_session.title
|
||||
print(f"新会话已创建: ID={session_id}, Title='{created_session.title}'")
|
||||
except ValueError as e: # 助手不存在等错误
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e: # LLM 调用或其他错误
|
||||
except Exception as e:
|
||||
print(f"创建会话时出错: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败")
|
||||
|
||||
# --- 调用聊天服务获取回复 ---
|
||||
try:
|
||||
# Pass db session to the service method
|
||||
ai_reply = await chat_service.get_ai_reply(
|
||||
db=db,
|
||||
user_message=user_message,
|
||||
session_id=session_id, # 使用真实的或新创建的 session_id
|
||||
session_id=session_id,
|
||||
assistant_id=assistant_id
|
||||
)
|
||||
print(f"发送 AI 回复: '{ai_reply}'")
|
||||
return ChatResponse(
|
||||
reply=ai_reply,
|
||||
session_id=response_session_id, # 返回新 ID (如果创建了)
|
||||
session_title=response_session_title # 返回新标题 (如果创建了)
|
||||
session_id=response_session_id,
|
||||
session_title=response_session_title
|
||||
)
|
||||
except ValueError as e: # 助手不存在等错误
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e: # LLM 调用或其他错误
|
||||
except Exception as e:
|
||||
print(f"处理聊天消息时发生错误: {e}")
|
||||
# The get_db_session dependency will handle rollback
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||
|
||||
33
backend/app/api/v1/endpoints/messages.py
Normal file
33
backend/app/api/v1/endpoints/messages.py
Normal file
@ -0,0 +1,33 @@
|
||||
# File: backend/app/api/v1/endpoints/messages.py (New)
|
||||
# Description: API endpoint for fetching messages
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from typing import List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.db.database import get_db_session
|
||||
from app.models.pydantic_models import MessageRead
|
||||
from app.db.models import MessageModel # Import DB model
|
||||
from sqlalchemy.future import select
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/session/{session_id}", response_model=List[MessageRead])
|
||||
async def read_messages_for_session(
|
||||
session_id: str,
|
||||
db: AsyncSession = Depends(get_db_session),
|
||||
skip: int = Query(0, ge=0), # Offset for pagination
|
||||
limit: int = Query(100, ge=1, le=500) # Limit number of messages
|
||||
):
|
||||
"""获取指定会话的消息列表 (按时间顺序)"""
|
||||
# TODO: Add check if session exists
|
||||
stmt = (
|
||||
select(MessageModel)
|
||||
.filter(MessageModel.session_id == session_id)
|
||||
.order_by(MessageModel.order.asc()) # Fetch in chronological order
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
messages = result.scalars().all()
|
||||
# Validate using Pydantic model before returning
|
||||
return [MessageRead.model_validate(msg) for msg in messages]
|
||||
@ -1,47 +1,43 @@
|
||||
# File: backend/app/api/v1/endpoints/sessions.py (新建)
|
||||
# Description: 会话管理的 API 路由
|
||||
# File: backend/app/api/v1/endpoints/sessions.py (Update with DB session dependency)
|
||||
# Description: 会话管理的 API 路由 (使用数据库会话)
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from typing import List
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from app.db.database import get_db_session
|
||||
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
|
||||
from app.services.session_service import session_service_instance, SessionService
|
||||
from app.services.session_service import SessionService # Import the class
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def get_session_service() -> SessionService:
|
||||
return session_service_instance
|
||||
session_service = SessionService() # Create instance
|
||||
|
||||
@router.post("/", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_new_session(
|
||||
session_data: SessionCreateRequest,
|
||||
service: SessionService = Depends(get_session_service)
|
||||
db: AsyncSession = Depends(get_db_session) # Inject DB session
|
||||
):
|
||||
"""创建新会话并自动生成标题"""
|
||||
try:
|
||||
return await service.create_session(session_data)
|
||||
return await session_service.create_session(db, session_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
|
||||
except Exception as e:
|
||||
# 处理可能的 LLM 调用错误
|
||||
print(f"创建会话时出错: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败")
|
||||
|
||||
@router.get("/assistant/{assistant_id}", response_model=List[SessionRead])
|
||||
async def read_sessions_for_assistant(
|
||||
assistant_id: str,
|
||||
service: SessionService = Depends(get_session_service)
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""获取指定助手的所有会话列表"""
|
||||
# TODO: 添加检查助手是否存在
|
||||
return service.get_sessions_by_assistant(assistant_id)
|
||||
# Consider adding check if assistant exists first
|
||||
return await session_service.get_sessions_by_assistant(db, assistant_id)
|
||||
|
||||
@router.get("/{session_id}", response_model=SessionRead)
|
||||
async def read_session_by_id(
|
||||
session_id: str,
|
||||
service: SessionService = Depends(get_session_service)
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""获取单个会话信息"""
|
||||
session = service.get_session(session_id)
|
||||
session = await session_service.get_session(db, session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话")
|
||||
return session
|
||||
@ -49,10 +45,8 @@ async def read_session_by_id(
|
||||
@router.delete("/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_existing_session(
|
||||
session_id: str,
|
||||
service: SessionService = Depends(get_session_service)
|
||||
db: AsyncSession = Depends(get_db_session)
|
||||
):
|
||||
"""删除指定 ID 的会话"""
|
||||
deleted = service.delete_session(session_id)
|
||||
deleted = await session_service.delete_session(db, session_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话")
|
||||
|
||||
|
||||
@ -12,4 +12,6 @@ load_dotenv()
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||||
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # 如果使用 Google
|
||||
|
||||
# 可以在这里添加其他配置项
|
||||
# Define the database URL (SQLite in this case)
|
||||
# DATABASE_URL = "sqlite+aiosqlite:///./cherryai.db" # Use async driver
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./cherryai.db")
|
||||
43
backend/app/db/database.py
Normal file
43
backend/app/db/database.py
Normal file
@ -0,0 +1,43 @@
|
||||
# File: backend/app/db/database.py (New - Database setup)
|
||||
# Description: SQLAlchemy 数据库引擎和会话设置
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from app.core.config import DATABASE_URL
|
||||
|
||||
# 创建异步数据库引擎
|
||||
# connect_args={"check_same_thread": False} is needed only for SQLite.
|
||||
# It's not needed for other databases.
|
||||
engine = create_async_engine(DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
|
||||
|
||||
# 创建异步会话工厂
|
||||
# expire_on_commit=False prevents attributes from expiring after commit.
|
||||
AsyncSessionFactory = sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
# 创建数据模型的基础类
|
||||
Base = declarative_base()
|
||||
|
||||
# --- Dependency to get DB session ---
|
||||
async def get_db_session() -> AsyncSession:
|
||||
"""FastAPI dependency to get an async database session."""
|
||||
async with AsyncSessionFactory() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit() # Commit transaction if successful
|
||||
except Exception:
|
||||
await session.rollback() # Rollback on error
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
# --- Function to create tables (call this on startup) ---
|
||||
async def create_db_and_tables():
|
||||
async with engine.begin() as conn:
|
||||
# await conn.run_sync(Base.metadata.drop_all) # Use drop_all carefully in dev
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
52
backend/app/db/models.py
Normal file
52
backend/app/db/models.py
Normal file
@ -0,0 +1,52 @@
|
||||
# File: backend/app/db/models.py (New - Database models)
|
||||
# Description: SQLAlchemy ORM 模型定义
|
||||
|
||||
from sqlalchemy import Column, String, Float, ForeignKey, Text, DateTime, Integer
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func # For default timestamps
|
||||
from app.db.database import Base
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
def generate_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
class AssistantModel(Base):
|
||||
__tablename__ = "assistants"
|
||||
|
||||
id = Column(String, primary_key=True, default=generate_uuid)
|
||||
name = Column(String(50), nullable=False, index=True)
|
||||
description = Column(String(200), nullable=True)
|
||||
avatar = Column(String(5), nullable=True)
|
||||
system_prompt = Column(Text, nullable=False)
|
||||
model = Column(String, nullable=False)
|
||||
temperature = Column(Float, nullable=False, default=0.7)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
sessions = relationship("SessionModel", back_populates="assistant", cascade="all, delete-orphan")
|
||||
|
||||
class SessionModel(Base):
|
||||
__tablename__ = "sessions"
|
||||
|
||||
id = Column(String, primary_key=True, default=generate_uuid)
|
||||
title = Column(String(100), nullable=False, default="New Chat")
|
||||
assistant_id = Column(String, ForeignKey("assistants.id"), nullable=False, index=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), index=True) # Index for sorting
|
||||
|
||||
assistant = relationship("AssistantModel", back_populates="sessions")
|
||||
messages = relationship("MessageModel", back_populates="session", cascade="all, delete-orphan", order_by="MessageModel.created_at") # Order messages by time
|
||||
|
||||
class MessageModel(Base):
|
||||
__tablename__ = "messages"
|
||||
|
||||
id = Column(String, primary_key=True, default=generate_uuid)
|
||||
session_id = Column(String, ForeignKey("sessions.id"), nullable=False, index=True)
|
||||
sender = Column(String(10), nullable=False) # 'user' or 'ai' or 'system'
|
||||
text = Column(Text, nullable=False)
|
||||
order = Column(Integer, nullable=False) # Explicit order within session
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
session = relationship("SessionModel", back_populates="messages")
|
||||
|
||||
@ -1,20 +1,31 @@
|
||||
# File: backend/app/main.py (确认 load_dotenv 调用位置)
|
||||
# Description: FastAPI 应用入口
|
||||
# File: backend/app/main.py (Update - Add startup event)
|
||||
# Description: FastAPI 应用入口 (添加数据库初始化)
|
||||
|
||||
from fastapi import FastAPI
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from app.api.v1.api import api_router as api_router_v1
|
||||
# 确保在创建 FastAPI 实例之前加载环境变量
|
||||
from app.core.config import OPENAI_API_KEY # 导入会触发 load_dotenv
|
||||
import app.core.config # Ensure config is loaded
|
||||
from app.db.database import create_db_and_tables # Import table creation function
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
# 创建 FastAPI 应用实例
|
||||
app = FastAPI(title="CherryAI Backend", version="0.1.0")
|
||||
# --- Lifespan context manager for startup/shutdown events ---
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup actions
|
||||
print("应用程序启动中...")
|
||||
await create_db_and_tables() # Create database tables on startup
|
||||
print("数据库表已检查/创建。")
|
||||
# You can add the default assistant creation here if needed,
|
||||
# but doing it in the service/model definition might be simpler for defaults.
|
||||
yield
|
||||
# Shutdown actions
|
||||
print("应用程序关闭中...")
|
||||
|
||||
# --- 配置 CORS ---
|
||||
origins = [
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
]
|
||||
# Create FastAPI app with lifespan context manager
|
||||
app = FastAPI(title="CherryAI Backend", version="0.1.0", lifespan=lifespan)
|
||||
|
||||
# --- CORS Middleware ---
|
||||
origins = [ "http://localhost:3000", "http://127.0.0.1:3000" ]
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
@ -23,10 +34,10 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# --- 挂载 API 路由 ---
|
||||
# --- API Routers ---
|
||||
app.include_router(api_router_v1, prefix="/api/v1")
|
||||
|
||||
# --- 根路径 ---
|
||||
# --- Root Endpoint ---
|
||||
@app.get("/", tags=["Root"])
|
||||
async def read_root():
|
||||
return {"message": "欢迎来到 CherryAI 后端!"}
|
||||
@ -1,9 +1,10 @@
|
||||
# File: backend/app/models/pydantic_models.py (更新)
|
||||
# File: backend/app/models/pydantic_models.py (Update Read models, add Message models)
|
||||
# Description: Pydantic 模型定义 API 数据结构
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List
|
||||
import uuid # 用于生成唯一 ID
|
||||
import uuid
|
||||
from datetime import datetime # Use datetime directly
|
||||
|
||||
# --- Assistant Models ---
|
||||
|
||||
@ -33,7 +34,8 @@ class AssistantUpdate(BaseModel):
|
||||
class AssistantRead(AssistantBase):
|
||||
"""读取助手信息时返回的模型 (包含 ID)"""
|
||||
id: str = Field(..., description="助手唯一 ID")
|
||||
|
||||
created_at: datetime # Add timestamps
|
||||
updated_at: Optional[datetime] = None
|
||||
class Config:
|
||||
from_attributes = True # Pydantic v2: orm_mode = True
|
||||
|
||||
@ -70,7 +72,22 @@ class SessionRead(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
assistant_id: str
|
||||
created_at: str
|
||||
created_at: datetime # Use datetime
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
# --- Message Models (New) ---
|
||||
class MessageBase(BaseModel):
|
||||
sender: str # 'user' or 'ai'
|
||||
text: str
|
||||
|
||||
class MessageRead(MessageBase):
|
||||
id: str
|
||||
session_id: str
|
||||
order: int
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@ -1,73 +1,76 @@
|
||||
# File: backend/app/services/assistant_service.py (新建)
|
||||
# Description: 管理助手数据的服务 (内存实现)
|
||||
# File: backend/app/services/assistant_service.py (Update with DB)
|
||||
# Description: 管理助手数据的服务 (使用 SQLAlchemy)
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate
|
||||
import uuid
|
||||
|
||||
# 使用字典作为内存数据库存储助手
|
||||
# key: assistant_id (str), value: AssistantRead object
|
||||
assistants_db: Dict[str, AssistantRead] = {}
|
||||
|
||||
# 添加默认助手 (确保 ID 与前端 Mock 一致)
|
||||
default_assistant = AssistantRead(
|
||||
id='asst-default',
|
||||
name='默认助手',
|
||||
description='通用聊天助手',
|
||||
avatar='🤖',
|
||||
system_prompt='你是一个乐于助人的 AI 助手。',
|
||||
model='gpt-3.5-turbo',
|
||||
temperature=0.7
|
||||
)
|
||||
assistants_db[default_assistant.id] = default_assistant
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import update as sqlalchemy_update, delete as sqlalchemy_delete
|
||||
from app.db.models import AssistantModel
|
||||
from app.models.pydantic_models import AssistantCreate, AssistantUpdate, AssistantRead
|
||||
|
||||
class AssistantService:
|
||||
"""助手数据的 CRUD 服务"""
|
||||
"""助手数据的 CRUD 服务 (数据库版)"""
|
||||
|
||||
def get_assistants(self) -> List[AssistantRead]:
|
||||
async def get_assistants(self, db: AsyncSession) -> List[AssistantRead]:
|
||||
"""获取所有助手"""
|
||||
return list(assistants_db.values())
|
||||
result = await db.execute(select(AssistantModel).order_by(AssistantModel.name))
|
||||
assistants = result.scalars().all()
|
||||
return [AssistantRead.model_validate(a) for a in assistants] # Use model_validate in Pydantic v2
|
||||
|
||||
def get_assistant(self, assistant_id: str) -> Optional[AssistantRead]:
|
||||
async def get_assistant(self, db: AsyncSession, assistant_id: str) -> Optional[AssistantRead]:
|
||||
"""根据 ID 获取单个助手"""
|
||||
return assistants_db.get(assistant_id)
|
||||
result = await db.execute(select(AssistantModel).filter(AssistantModel.id == assistant_id))
|
||||
assistant = result.scalars().first()
|
||||
return AssistantRead.model_validate(assistant) if assistant else None
|
||||
|
||||
def create_assistant(self, assistant_data: AssistantCreate) -> AssistantRead:
|
||||
async def create_assistant(self, db: AsyncSession, assistant_data: AssistantCreate) -> AssistantRead:
|
||||
"""创建新助手"""
|
||||
new_id = f"asst-{uuid.uuid4()}" # 生成唯一 ID
|
||||
new_assistant = AssistantRead(id=new_id, **assistant_data.model_dump())
|
||||
assistants_db[new_id] = new_assistant
|
||||
print(f"助手已创建: {new_id} - {new_assistant.name}")
|
||||
return new_assistant
|
||||
# 使用 Pydantic 模型创建 DB 模型实例
|
||||
db_assistant = AssistantModel(**assistant_data.model_dump())
|
||||
# ID will be generated by default in the model
|
||||
db.add(db_assistant)
|
||||
await db.flush() # Flush to get the generated ID and defaults
|
||||
await db.refresh(db_assistant) # Refresh to load all attributes
|
||||
print(f"助手已创建 (DB): {db_assistant.id} - {db_assistant.name}")
|
||||
return AssistantRead.model_validate(db_assistant)
|
||||
|
||||
def update_assistant(self, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]:
|
||||
async def update_assistant(self, db: AsyncSession, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]:
|
||||
"""更新现有助手"""
|
||||
existing_assistant = assistants_db.get(assistant_id)
|
||||
if not existing_assistant:
|
||||
return None
|
||||
update_values = assistant_data.model_dump(exclude_unset=True)
|
||||
if not update_values:
|
||||
# If nothing to update, just fetch and return the existing one
|
||||
return await self.get_assistant(db, assistant_id)
|
||||
|
||||
# 使用 Pydantic 的 model_copy 和 update 来更新字段
|
||||
update_data = assistant_data.model_dump(exclude_unset=True) # 只获取设置了值的字段
|
||||
if update_data:
|
||||
updated_assistant = existing_assistant.model_copy(update=update_data)
|
||||
assistants_db[assistant_id] = updated_assistant
|
||||
print(f"助手已更新: {assistant_id}")
|
||||
return updated_assistant
|
||||
return existing_assistant # 如果没有更新任何字段,返回原始助手
|
||||
# Execute update statement
|
||||
stmt = (
|
||||
sqlalchemy_update(AssistantModel)
|
||||
.where(AssistantModel.id == assistant_id)
|
||||
.values(**update_values)
|
||||
.returning(AssistantModel) # Return the updated row
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
updated_assistant = result.scalars().first()
|
||||
|
||||
def delete_assistant(self, assistant_id: str) -> bool:
|
||||
if updated_assistant:
|
||||
await db.flush()
|
||||
await db.refresh(updated_assistant)
|
||||
print(f"助手已更新 (DB): {assistant_id}")
|
||||
return AssistantRead.model_validate(updated_assistant)
|
||||
return None # Assistant not found
|
||||
|
||||
async def delete_assistant(self, db: AsyncSession, assistant_id: str) -> bool:
|
||||
"""删除助手"""
|
||||
if assistant_id in assistants_db:
|
||||
# 添加逻辑:不允许删除默认助手
|
||||
if assistant_id == 'asst-default':
|
||||
# Prevent deleting default assistant
|
||||
if assistant_id == 'asst-default': # Assuming 'asst-default' is a known ID
|
||||
print("尝试删除默认助手 - 操作被阻止")
|
||||
return False # 或者抛出特定异常
|
||||
del assistants_db[assistant_id]
|
||||
print(f"助手已删除: {assistant_id}")
|
||||
# TODO: 在实际应用中,还需要删除关联的会话和消息
|
||||
return False
|
||||
|
||||
stmt = sqlalchemy_delete(AssistantModel).where(AssistantModel.id == assistant_id)
|
||||
result = await db.execute(stmt)
|
||||
if result.rowcount > 0:
|
||||
await db.flush()
|
||||
print(f"助手已删除 (DB): {assistant_id}")
|
||||
# Deletion of sessions/messages handled by cascade="all, delete-orphan"
|
||||
return True
|
||||
return False
|
||||
|
||||
# 创建服务实例
|
||||
assistant_service_instance = AssistantService()
|
||||
|
||||
|
||||
@ -1,127 +1,137 @@
|
||||
# File: backend/app/services/chat_service.py (更新)
|
||||
# Description: 封装 LangChain 聊天逻辑 (支持助手配置和会话历史)
|
||||
# File: backend/app/services/chat_service.py (Update with DB for history)
|
||||
# Description: 封装 LangChain 聊天逻辑 (使用数据库存储和检索消息)
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
|
||||
from typing import Dict, List, Optional
|
||||
from app.services.assistant_service import assistant_service_instance # 获取助手配置
|
||||
from app.models.pydantic_models import AssistantRead # 引入助手模型
|
||||
import app.core.config as Config
|
||||
|
||||
# --- 更新内存管理 ---
|
||||
# 使用字典存储不同会话的内存
|
||||
# key: session_id (str), value: List[BaseMessage]
|
||||
chat_history_db: Dict[str, List[BaseMessage]] = {}
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from app.db.models import MessageModel, AssistantModel # Import DB models
|
||||
from app.services.assistant_service import AssistantService # Use class directly
|
||||
from app.models.pydantic_models import AssistantRead
|
||||
|
||||
class ChatService:
|
||||
"""处理 AI 聊天交互的服务 (支持助手配置)"""
|
||||
"""处理 AI 聊天交互的服务 (使用数据库历史)"""
|
||||
|
||||
def __init__(self, default_api_key: str):
|
||||
"""初始化时可传入默认 API Key"""
|
||||
self.default_api_key = default_api_key
|
||||
# 不再在 init 中创建固定的 LLM 和 chain
|
||||
self.assistant_service = AssistantService() # Instantiate assistant service
|
||||
|
||||
def _get_llm(self, assistant: AssistantRead) -> ChatOpenAI:
|
||||
"""根据助手配置动态创建 LLM 实例"""
|
||||
# TODO: 支持不同模型提供商 (Gemini, Anthropic etc.)
|
||||
# ... (remains the same) ...
|
||||
if assistant.model.startswith("gpt"):
|
||||
return ChatOpenAI(
|
||||
model=assistant.model,
|
||||
api_key=self.default_api_key, # 或从助手配置中读取特定 key
|
||||
temperature=assistant.temperature
|
||||
)
|
||||
return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)
|
||||
elif assistant.model.startswith("gemini"):
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=assistant.model,
|
||||
api_key=self.default_api_key, # 或从助手配置中读取特定 key
|
||||
temperature=assistant.temperature
|
||||
)
|
||||
else:
|
||||
# 默认或抛出错误
|
||||
print(f"警告: 模型 {assistant.model} 未明确支持,尝试使用 ChatOpenAI")
|
||||
return ChatOpenAI(
|
||||
model=assistant.model,
|
||||
api_key=self.default_api_key,
|
||||
temperature=assistant.temperature
|
||||
)
|
||||
return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)
|
||||
|
||||
async def get_ai_reply(self, user_message: str, session_id: str, assistant_id: str) -> str:
|
||||
"""
|
||||
获取 AI 对用户消息的回复 (使用指定助手和会话历史)
|
||||
Args:
|
||||
user_message (str): 用户发送的消息
|
||||
session_id (str): 会话 ID
|
||||
assistant_id (str): 使用的助手 ID
|
||||
Returns:
|
||||
str: AI 的回复文本
|
||||
Raises:
|
||||
ValueError: 如果找不到指定的助手
|
||||
Exception: 如果调用 AI 服务时发生错误
|
||||
"""
|
||||
|
||||
async def _get_chat_history(self, db: AsyncSession, session_id: str, limit: int = 10) -> Tuple[List[BaseMessage], int]:
|
||||
"""从数据库加载指定会话的历史消息 (按 order 排序)"""
|
||||
stmt = (
|
||||
select(MessageModel)
|
||||
.filter(MessageModel.session_id == session_id)
|
||||
.order_by(MessageModel.order.desc()) # Get latest first
|
||||
.limit(limit)
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
db_messages = result.scalars().all()
|
||||
|
||||
# Convert to LangChain messages (in correct order: oldest to newest)
|
||||
history: List[BaseMessage] = []
|
||||
max_order = 0
|
||||
for msg in reversed(db_messages): # Reverse to get chronological order
|
||||
if msg.sender == 'user':
|
||||
history.append(HumanMessage(content=msg.text))
|
||||
elif msg.sender == 'ai':
|
||||
history.append(AIMessage(content=msg.text))
|
||||
# Add handling for 'system' if needed
|
||||
max_order = max(max_order, msg.order) # Keep track of the latest order number
|
||||
|
||||
return history, max_order
|
||||
|
||||
async def _save_message(self, db: AsyncSession, session_id: str, sender: str, text: str, order: int):
|
||||
"""将消息保存到数据库"""
|
||||
db_message = MessageModel(
|
||||
session_id=session_id,
|
||||
sender=sender,
|
||||
text=text,
|
||||
order=order
|
||||
)
|
||||
db.add(db_message)
|
||||
await db.flush() # Ensure it's added before potential commit
|
||||
print(f"消息已保存 (DB): Session={session_id}, Order={order}, Sender={sender}")
|
||||
|
||||
|
||||
async def get_ai_reply(self, db: AsyncSession, user_message: str, session_id: str, assistant_id: str) -> str:
|
||||
"""获取 AI 回复,并保存用户消息和 AI 回复到数据库"""
|
||||
# 1. 获取助手配置
|
||||
assistant = assistant_service_instance.get_assistant(assistant_id)
|
||||
assistant = await self.assistant_service.get_assistant(db, assistant_id)
|
||||
if not assistant:
|
||||
raise ValueError(f"找不到助手 ID: {assistant_id}")
|
||||
|
||||
# 2. 获取或初始化当前会话的历史记录
|
||||
current_chat_history = chat_history_db.get(session_id, [])
|
||||
# 2. 获取历史记录和下一个序号
|
||||
current_chat_history, last_order = await self._get_chat_history(db, session_id)
|
||||
user_message_order = last_order + 1
|
||||
ai_message_order = last_order + 2
|
||||
|
||||
# 3. 构建 Prompt (包含动态系统提示)
|
||||
# 3. 构建 Prompt
|
||||
prompt = ChatPromptTemplate.from_messages([
|
||||
SystemMessage(content=assistant.system_prompt), # 使用助手的系统提示
|
||||
SystemMessage(content=assistant.system_prompt),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
HumanMessage(content="{input}"),
|
||||
HumanMessage(content=user_message),
|
||||
])
|
||||
|
||||
# 4. 获取 LLM 实例
|
||||
# 4. 获取 LLM
|
||||
llm = self._get_llm(assistant)
|
||||
|
||||
# 5. 定义输出解析器
|
||||
output_parser = StrOutputParser()
|
||||
|
||||
# 6. 构建 LCEL 链
|
||||
chain = prompt | llm | output_parser
|
||||
|
||||
try:
|
||||
# 7. 调用链获取回复
|
||||
# --- Save user message BEFORE calling LLM ---
|
||||
await self._save_message(db, session_id, 'user', user_message, user_message_order)
|
||||
|
||||
# 5. 调用链获取回复
|
||||
ai_response = await chain.ainvoke({
|
||||
"input": user_message,
|
||||
"chat_history": current_chat_history,
|
||||
"chat_history": current_chat_history, # Pass history fetched from DB
|
||||
})
|
||||
|
||||
# 8. 更新会话历史记录
|
||||
current_chat_history.append(HumanMessage(content=user_message))
|
||||
current_chat_history.append(AIMessage(content=ai_response))
|
||||
# 限制历史记录长度 (例如最近 10 条消息)
|
||||
max_history_length = 10
|
||||
if len(current_chat_history) > max_history_length:
|
||||
chat_history_db[session_id] = current_chat_history[-max_history_length:]
|
||||
else:
|
||||
chat_history_db[session_id] = current_chat_history
|
||||
# --- Save AI response AFTER getting it ---
|
||||
await self._save_message(db, session_id, 'ai', ai_response, ai_message_order)
|
||||
|
||||
# Note: We don't need to manage history in memory anymore (chat_history_db removed)
|
||||
return ai_response
|
||||
|
||||
except Exception as e:
|
||||
# Consider rolling back the user message save if LLM call fails,
|
||||
# although often it's better to keep the user message.
|
||||
# await db.rollback() # Handled by get_db_session dependency on error
|
||||
print(f"调用 LangChain 时出错 (助手: {assistant_id}, 会话: {session_id}): {e}")
|
||||
raise Exception(f"AI 服务暂时不可用: {e}")
|
||||
|
||||
# (可选) 添加一个简单的文本生成方法用于生成标题
|
||||
async def generate_text(self, prompt_text: str, model_name: str = "gpt-3.5-turbo", temperature: float = 0.5) -> str:
|
||||
"""使用指定模型生成文本 (用于标题等)"""
|
||||
async def generate_text(self, prompt_text: str, model_name: str = "gemini-2.0-flash", temperature: float = 0.5) -> str:
|
||||
# ... (remains the same) ...
|
||||
try:
|
||||
# 使用一个临时的、可能更便宜的模型
|
||||
temp_llm = ChatOpenAI(model=model_name, api_key=self.default_api_key, temperature=temperature)
|
||||
temp_llm = ChatGoogleGenerativeAI(
|
||||
model=model_name,
|
||||
api_key=self.default_api_key, # 或从助手配置中读取特定 key
|
||||
temperature=temperature
|
||||
)
|
||||
response = await temp_llm.ainvoke(prompt_text)
|
||||
return response.content
|
||||
except Exception as e:
|
||||
print(f"生成文本时出错: {e}")
|
||||
return "无法生成标题" # 返回默认值或抛出异常
|
||||
return "无法生成标题"
|
||||
|
||||
|
||||
# --- 创建 ChatService 实例 ---
|
||||
if not Config.GOOGLE_API_KEY:
|
||||
raise ValueError("请在 .env 文件中设置 OPENAI_API_KEY")
|
||||
chat_service_instance = ChatService(default_api_key=Config.GOOGLE_API_KEY)
|
||||
# ChatService instance is now created where needed or injected, no global instance here.
|
||||
|
||||
@ -1,82 +1,74 @@
|
||||
# File: backend/app/services/session_service.py (新建)
|
||||
# Description: 管理会话数据的服务 (内存实现)
|
||||
# File: backend/app/services/session_service.py (Update with DB)
|
||||
# Description: 管理会话数据的服务 (使用 SQLAlchemy)
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse, AssistantRead
|
||||
from app.services.assistant_service import assistant_service_instance # 需要获取助手信息
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import delete as sqlalchemy_delete
|
||||
from app.db.models import SessionModel, AssistantModel # Import DB models
|
||||
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
# 导入 ChatService 以调用 LLM 生成标题 (避免循环导入,考虑重构)
|
||||
# from app.services.chat_service import chat_service_instance
|
||||
|
||||
# 内存数据库存储会话
|
||||
# key: session_id (str), value: SessionRead object
|
||||
sessions_db: Dict[str, SessionRead] = {}
|
||||
|
||||
# Import ChatService for title generation (consider refactoring later)
|
||||
from app.services.chat_service import ChatService
|
||||
import app.core.config as Config
|
||||
chat_service_instance = ChatService(Config.GOOGLE_API_KEY)
|
||||
class SessionService:
|
||||
"""会话数据的 CRUD 及标题生成服务"""
|
||||
"""会话数据的 CRUD 及标题生成服务 (数据库版)"""
|
||||
|
||||
async def create_session(self, session_data: SessionCreateRequest) -> SessionCreateResponse:
|
||||
async def create_session(self, db: AsyncSession, session_data: SessionCreateRequest) -> SessionCreateResponse:
|
||||
"""创建新会话并生成标题"""
|
||||
assistant = assistant_service_instance.get_assistant(session_data.assistant_id)
|
||||
# 检查助手是否存在
|
||||
result = await db.execute(select(AssistantModel).filter(AssistantModel.id == session_data.assistant_id))
|
||||
assistant = result.scalars().first()
|
||||
if not assistant:
|
||||
raise ValueError("指定的助手不存在")
|
||||
|
||||
new_id = f"session-{uuid.uuid4()}"
|
||||
created_time = datetime.now(timezone.utc)
|
||||
# --- 调用 LLM 生成标题 ---
|
||||
try:
|
||||
title_prompt = f"根据以下用户第一条消息,为此对话生成一个简洁的标题(不超过10个字):\n\n{session_data.first_message}"
|
||||
generated_title = await chat_service_instance.generate_text(title_prompt)
|
||||
except Exception as e:
|
||||
print(f"生成会话标题时出错: {e}")
|
||||
generated_title = f"关于 \"{session_data.first_message[:15]}...\"" # Fallback
|
||||
# --- 生成结束 ---
|
||||
|
||||
# --- TODO: 调用 LLM 生成标题 ---
|
||||
# title_prompt = f"根据以下用户第一条消息,为此对话生成一个简洁的标题(不超过10个字):\n\n{session_data.first_message}"
|
||||
# generated_title = await chat_service_instance.generate_text(title_prompt) # 需要一个简单的文本生成方法
|
||||
|
||||
# 模拟标题生成
|
||||
generated_title = f"关于 \"{session_data.first_message[:15]}...\""
|
||||
print(f"为新会话 {new_id} 生成标题: {generated_title}")
|
||||
# --- 模拟结束 ---
|
||||
|
||||
new_session = SessionRead(
|
||||
id=new_id,
|
||||
db_session = SessionModel(
|
||||
title=generated_title,
|
||||
assistant_id=session_data.assistant_id,
|
||||
created_at=created_time.isoformat() # 存储 ISO 格式字符串
|
||||
assistant_id=session_data.assistant_id
|
||||
# ID and created_at have defaults
|
||||
)
|
||||
sessions_db[new_id] = new_session
|
||||
print(f"会话已创建: {new_id}")
|
||||
db.add(db_session)
|
||||
await db.flush()
|
||||
await db.refresh(db_session)
|
||||
print(f"会话已创建 (DB): {db_session.id}")
|
||||
|
||||
return SessionCreateResponse(
|
||||
id=new_session.id,
|
||||
title=new_session.title,
|
||||
assistant_id=new_session.assistant_id,
|
||||
created_at=new_session.created_at
|
||||
id=db_session.id,
|
||||
title=db_session.title,
|
||||
assistant_id=db_session.assistant_id,
|
||||
created_at=db_session.created_at.isoformat() # Use datetime from DB model
|
||||
)
|
||||
|
||||
def get_sessions_by_assistant(self, assistant_id: str) -> List[SessionRead]:
|
||||
async def get_sessions_by_assistant(self, db: AsyncSession, assistant_id: str) -> List[SessionRead]:
|
||||
"""获取指定助手的所有会话"""
|
||||
return [s for s in sessions_db.values() if s.assistant_id == assistant_id]
|
||||
stmt = select(SessionModel).filter(SessionModel.assistant_id == assistant_id).order_by(SessionModel.updated_at.desc()) # Order by update time
|
||||
result = await db.execute(stmt)
|
||||
sessions = result.scalars().all()
|
||||
return [SessionRead.model_validate(s) for s in sessions]
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[SessionRead]:
|
||||
async def get_session(self, db: AsyncSession, session_id: str) -> Optional[SessionRead]:
|
||||
"""获取单个会话"""
|
||||
return sessions_db.get(session_id)
|
||||
result = await db.execute(select(SessionModel).filter(SessionModel.id == session_id))
|
||||
session = result.scalars().first()
|
||||
return SessionRead.model_validate(session) if session else None
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
async def delete_session(self, db: AsyncSession, session_id: str) -> bool:
|
||||
"""删除会话"""
|
||||
if session_id in sessions_db:
|
||||
del sessions_db[session_id]
|
||||
print(f"会话已删除: {session_id}")
|
||||
# TODO: 删除关联的消息
|
||||
stmt = sqlalchemy_delete(SessionModel).where(SessionModel.id == session_id)
|
||||
result = await db.execute(stmt)
|
||||
if result.rowcount > 0:
|
||||
await db.flush()
|
||||
print(f"会话已删除 (DB): {session_id}")
|
||||
# Deletion of messages handled by cascade
|
||||
return True
|
||||
return False
|
||||
|
||||
def delete_sessions_by_assistant(self, assistant_id: str) -> int:
|
||||
"""删除指定助手的所有会话"""
|
||||
ids_to_delete = [s.id for s in sessions_db.values() if s.assistant_id == assistant_id]
|
||||
count = 0
|
||||
for session_id in ids_to_delete:
|
||||
if self.delete_session(session_id):
|
||||
count += 1
|
||||
print(f"删除了助手 {assistant_id} 的 {count} 个会话")
|
||||
return count
|
||||
|
||||
|
||||
# 创建服务实例
|
||||
session_service_instance = SessionService()
|
||||
|
||||
BIN
backend/cherryai.db
Normal file
BIN
backend/cherryai.db
Normal file
Binary file not shown.
@ -1,38 +1,82 @@
|
||||
// File: frontend/app/chat/page.tsx (更新以使用 API)
|
||||
// Description: 对接后端 API 实现助手和会话的加载与管理
|
||||
|
||||
'use client';
|
||||
"use client";
|
||||
|
||||
import React, { useState, useRef, useEffect, useCallback } from 'react';
|
||||
import { SendHorizontal, Loader2, PanelRightOpen, PanelRightClose, UserPlus, Settings2, Trash2, Edit, RefreshCw } from 'lucide-react'; // 添加刷新图标
|
||||
import React, { useState, useRef, useEffect, useCallback } from "react";
|
||||
import {
|
||||
SendHorizontal,
|
||||
Loader2,
|
||||
PanelRightOpen,
|
||||
PanelRightClose,
|
||||
UserPlus,
|
||||
Settings2,
|
||||
Trash2,
|
||||
Edit,
|
||||
RefreshCw,
|
||||
} from "lucide-react"; // 添加刷新图标
|
||||
import { useForm } from "react-hook-form";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import * as z from "zod";
|
||||
|
||||
// Shadcn UI Components
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Dialog, DialogContent, DialogDescription, DialogHeader, DialogTitle, DialogTrigger, DialogFooter, DialogClose } from "@/components/ui/dialog";
|
||||
import { Form, FormControl, FormDescription, FormField, FormItem, FormLabel, FormMessage } from "@/components/ui/form";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
DialogFooter,
|
||||
DialogClose,
|
||||
} from "@/components/ui/dialog";
|
||||
import {
|
||||
Form,
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from "@/components/ui/form";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Slider } from "@/components/ui/slider";
|
||||
import { Toaster, toast } from "sonner";
|
||||
import { Skeleton } from "@/components/ui/skeleton"; // 导入骨架屏
|
||||
|
||||
// API 函数和类型
|
||||
import {
|
||||
sendChatMessage, getAssistants, createAssistant, updateAssistant, deleteAssistant,
|
||||
getSessionsByAssistant, deleteSession,
|
||||
Assistant, Session, AssistantCreateData, AssistantUpdateData, ChatApiResponse
|
||||
} from '@/lib/api'; // 确保路径正确
|
||||
sendChatMessage,
|
||||
getAssistants,
|
||||
createAssistant,
|
||||
updateAssistant,
|
||||
deleteAssistant,
|
||||
getSessionsByAssistant,
|
||||
deleteSession,
|
||||
getMessagesBySession,
|
||||
Session,
|
||||
ChatApiResponse,
|
||||
Message as ApiMessage,
|
||||
} from "@/lib/api"; // 确保路径正确
|
||||
import {
|
||||
Assistant,
|
||||
AssistantCreateData,
|
||||
AssistantUpdateData,
|
||||
} from "@/types/assistant";
|
||||
|
||||
// --- 数据接口定义 ---
|
||||
interface Message {
|
||||
id: string;
|
||||
text: string;
|
||||
sender: 'user' | 'ai';
|
||||
isError?: boolean;
|
||||
// --- Frontend specific Message type (includes optional isError) ---
|
||||
interface Message extends ApiMessage {
|
||||
// Extend the type from API
|
||||
isError?: boolean; // Optional flag for frontend error styling
|
||||
}
|
||||
|
||||
interface ChatSession {
|
||||
@ -45,10 +89,16 @@ interface ChatSession {
|
||||
|
||||
// --- Zod Schema for Assistant Form Validation ---
|
||||
const assistantFormSchema = z.object({
|
||||
name: z.string().min(1, { message: "助手名称不能为空" }).max(50, { message: "名称过长" }),
|
||||
name: z
|
||||
.string()
|
||||
.min(1, { message: "助手名称不能为空" })
|
||||
.max(50, { message: "名称过长" }),
|
||||
description: z.string().max(200, { message: "描述过长" }).optional(),
|
||||
avatar: z.string().max(5, { message: "头像/Emoji 过长" }).optional(), // 简单限制长度
|
||||
system_prompt: z.string().min(1, { message: "系统提示不能为空" }).max(4000, { message: "系统提示过长" }),
|
||||
system_prompt: z
|
||||
.string()
|
||||
.min(1, { message: "系统提示不能为空" })
|
||||
.max(4000, { message: "系统提示过长" }),
|
||||
model: z.string({ required_error: "请选择一个模型" }),
|
||||
temperature: z.number().min(0).max(1),
|
||||
});
|
||||
@ -66,9 +116,12 @@ const availableModels = [
|
||||
];
|
||||
|
||||
// --- Helper Function ---
|
||||
const findLastSession = (sessions: ChatSession[], assistantId: string): ChatSession | undefined => {
|
||||
const findLastSession = (
|
||||
sessions: ChatSession[],
|
||||
assistantId: string
|
||||
): ChatSession | undefined => {
|
||||
return sessions
|
||||
.filter(s => s.assistantId === assistantId && !s.isTemporary)
|
||||
.filter((s) => s.assistantId === assistantId && !s.isTemporary)
|
||||
.sort((a, b) => b.createdAt.getTime() - a.createdAt.getTime())[0];
|
||||
};
|
||||
|
||||
@ -182,7 +235,7 @@ function AssistantForm({ assistant, onSave, onClose }: AssistantFormProps) {
|
||||
</SelectTrigger>
|
||||
</FormControl>
|
||||
<SelectContent>
|
||||
{availableModels.map(model => (
|
||||
{availableModels.map((model) => (
|
||||
<SelectItem key={model.value} value={model.value}>
|
||||
{model.label}
|
||||
</SelectItem>
|
||||
@ -199,7 +252,9 @@ function AssistantForm({ assistant, onSave, onClose }: AssistantFormProps) {
|
||||
name="temperature"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>温度 (Temperature): {field.value.toFixed(1)}</FormLabel>
|
||||
<FormLabel>
|
||||
温度 (Temperature): {field.value.toFixed(1)}
|
||||
</FormLabel>
|
||||
<FormControl>
|
||||
{/* Shadcn Slider expects an array for value */}
|
||||
<Slider
|
||||
@ -211,20 +266,22 @@ function AssistantForm({ assistant, onSave, onClose }: AssistantFormProps) {
|
||||
className="py-2" // Add padding for better interaction
|
||||
/>
|
||||
</FormControl>
|
||||
<FormDescription>
|
||||
值越低越稳定,越高越有创造性。
|
||||
</FormDescription>
|
||||
<FormDescription>值越低越稳定,越高越有创造性。</FormDescription>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<DialogFooter>
|
||||
<DialogClose asChild>
|
||||
<Button type="button" variant="outline" disabled={isSaving}>取消</Button>
|
||||
<Button type="button" variant="outline" disabled={isSaving}>
|
||||
取消
|
||||
</Button>
|
||||
</DialogClose>
|
||||
<Button type="submit" disabled={isSaving}>
|
||||
{isSaving ? <Loader2 className="mr-2 h-4 w-4 animate-spin" /> : null}
|
||||
{isSaving ? '保存中...' : '保存助手'}
|
||||
{isSaving ? (
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
) : null}
|
||||
{isSaving ? "保存中..." : "保存助手"}
|
||||
</Button>
|
||||
</DialogFooter>
|
||||
</form>
|
||||
@ -232,11 +289,11 @@ function AssistantForm({ assistant, onSave, onClose }: AssistantFormProps) {
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// --- Main Chat Page Component ---
|
||||
export default function ChatPage() {
|
||||
// --- State Variables ---
|
||||
const [inputMessage, setInputMessage] = useState('');
|
||||
const [inputMessage, setInputMessage] = useState("");
|
||||
// Messages state now holds Message type from API
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
|
||||
const [isLoading, setIsLoading] = useState(false); // AI 回复加载状态
|
||||
@ -244,15 +301,19 @@ export default function ChatPage() {
|
||||
|
||||
const [isSessionPanelOpen, setIsSessionPanelOpen] = useState(true);
|
||||
const [isAssistantDialogOpen, setIsAssistantDialogOpen] = useState(false); // 控制助手表单 Dialog 显隐
|
||||
const [editingAssistant, setEditingAssistant] = useState<Assistant | null>(null); // 当前正在编辑的助手
|
||||
const [editingAssistant, setEditingAssistant] = useState<Assistant | null>(
|
||||
null
|
||||
); // 当前正在编辑的助手
|
||||
|
||||
// Data Loading States
|
||||
const [assistantsLoading, setAssistantsLoading] = useState(true);
|
||||
const [sessionsLoading, setSessionsLoading] = useState(false);
|
||||
|
||||
const [messagesLoading, setMessagesLoading] = useState(false);
|
||||
// Data State
|
||||
const [assistants, setAssistants] = useState<Assistant[]>([]);
|
||||
const [currentAssistantId, setCurrentAssistantId] = useState<string | null>(null); // 初始为 null
|
||||
const [currentAssistantId, setCurrentAssistantId] = useState<string | null>(
|
||||
null
|
||||
); // 初始为 null
|
||||
const [allSessions, setAllSessions] = useState<Session[]>([]);
|
||||
const [currentSessionId, setCurrentSessionId] = useState<string | null>(null); // 初始为 null
|
||||
|
||||
@ -261,19 +322,20 @@ export default function ChatPage() {
|
||||
|
||||
// --- Effects ---
|
||||
// Initial data loading (Assistants)
|
||||
// Initial Assistant loading
|
||||
useEffect(() => {
|
||||
const loadAssistants = async () => {
|
||||
setAssistantsLoading(true);
|
||||
try {
|
||||
const fetchedAssistants = await getAssistants();
|
||||
setAssistants(fetchedAssistants);
|
||||
// 设置默认选中的助手 (例如第一个或 ID 为 'asst-default' 的)
|
||||
const defaultAssistant = fetchedAssistants.find(a => a.id === 'asst-default') || fetchedAssistants[0];
|
||||
const defaultAssistant =
|
||||
fetchedAssistants.find((a) => a.id === "asst-default") ||
|
||||
fetchedAssistants[0];
|
||||
if (defaultAssistant) {
|
||||
setCurrentAssistantId(defaultAssistant.id);
|
||||
} else {
|
||||
console.warn("No default or initial assistant found.");
|
||||
// 可能需要提示用户创建助手
|
||||
}
|
||||
} catch (apiError: any) {
|
||||
toast.error(`加载助手列表失败: ${apiError.message}`);
|
||||
@ -283,63 +345,92 @@ export default function ChatPage() {
|
||||
}
|
||||
};
|
||||
loadAssistants();
|
||||
}, []); // 空依赖数组,只在挂载时运行一次
|
||||
|
||||
// Load sessions when assistant changes
|
||||
}, []);
|
||||
// Load sessions when assistant changes (remains same, but calls handleSelectSession internally)
|
||||
useEffect(() => {
|
||||
if (!currentAssistantId) return; // 如果没有选中助手,则不加载
|
||||
|
||||
if (!currentAssistantId) return;
|
||||
const loadSessions = async () => {
|
||||
setSessionsLoading(true);
|
||||
// 清空当前会话和消息列表
|
||||
setCurrentSessionId(null);
|
||||
setMessages([]);
|
||||
try {
|
||||
const fetchedSessions = await getSessionsByAssistant(currentAssistantId);
|
||||
// 更新全局会话列表 (只保留其他助手的会话,加上当前助手的)
|
||||
// Filter out sessions that might belong to a deleted assistant still in cache
|
||||
const validAssistants = new Set(assistants.map(a => a.id));
|
||||
setAllSessions(prev => [
|
||||
...prev.filter(s => s.assistant_id !== currentAssistantId),
|
||||
...prev.filter(s => s.assistant_id !== currentAssistantId && validAssistants.has(s.assistant_id)),
|
||||
...fetchedSessions
|
||||
]);
|
||||
|
||||
// 查找最新的会话并设为当前
|
||||
const lastSession = fetchedSessions
|
||||
.sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime())[0];
|
||||
|
||||
const lastSession = fetchedSessions.sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime())[0];
|
||||
if (lastSession) {
|
||||
setCurrentSessionId(lastSession.id);
|
||||
// TODO: 加载 lastSession.id 的历史消息
|
||||
console.log(`加载助手 ${currentAssistantId} 的最后一个会话: ${lastSession.id}`);
|
||||
const currentAssistant = assistants.find(a => a.id === currentAssistantId);
|
||||
setMessages([ { id: `init-${lastSession.id}-1`, text: `继续与 ${currentAssistant?.name || '助手'} 的对话: ${lastSession.title}`, sender: 'ai' } ]);
|
||||
setCurrentSessionId(lastSession.id); // Trigger message loading effect
|
||||
} else {
|
||||
// 没有历史会话,进入临时新对话状态
|
||||
setCurrentSessionId('temp-new-chat');
|
||||
console.log(`助手 ${currentAssistantId} 没有历史会话,创建临时新对话`);
|
||||
const currentAssistant = assistants.find(a => a.id === currentAssistantId);
|
||||
setMessages([ { id: `init-temp-${currentAssistantId}`, text: `开始与 ${currentAssistant?.name || '助手'} 的新对话吧!`, sender: 'ai' } ]);
|
||||
setMessages([{ id: `init-temp-${currentAssistantId}`, session_id: 'temp-new-chat', sender: 'ai', text: `开始与 ${currentAssistant?.name || '助手'} 的新对话吧!`, order: 0, created_at: new Date().toISOString() }]);
|
||||
}
|
||||
} catch (apiError: any) { toast.error(`加载会话列表失败: ${apiError.message}`); }
|
||||
finally { setSessionsLoading(false); }
|
||||
};
|
||||
loadSessions();
|
||||
}, [currentAssistantId]); // 空依赖数组,只在挂载时运行一次
|
||||
|
||||
// Load sessions when assistant changes
|
||||
useEffect(() => {
|
||||
if (!currentSessionId || currentSessionId === "temp-new-chat") {
|
||||
// If it's temp-new-chat, messages are already set or should be empty initially
|
||||
if (currentSessionId === "temp-new-chat" && messages.length === 0) {
|
||||
// Ensure initial message is set
|
||||
const currentAssistant = assistants.find(
|
||||
(a) => a.id === currentAssistantId
|
||||
);
|
||||
setMessages([
|
||||
{
|
||||
id: `init-temp-${currentAssistantId}`,
|
||||
session_id: "temp-new-chat",
|
||||
sender: "ai",
|
||||
text: `开始与 ${currentAssistant?.name || "助手"} 的新对话吧!`,
|
||||
order: 0,
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const loadMessages = async () => {
|
||||
setMessagesLoading(true);
|
||||
setError(null); // Clear previous errors
|
||||
console.log(`加载会话 ${currentSessionId} 的消息...`);
|
||||
try {
|
||||
const fetchedMessages = await getMessagesBySession(currentSessionId);
|
||||
setMessages(fetchedMessages);
|
||||
console.log(`成功加载 ${fetchedMessages.length} 条消息`);
|
||||
} catch (apiError: any) {
|
||||
toast.error(`加载会话列表失败: ${apiError.message}`);
|
||||
toast.error(`加载消息失败: ${apiError.message}`);
|
||||
setError(`无法加载消息: ${apiError.message}`);
|
||||
setMessages([]); // Clear messages on error
|
||||
} finally {
|
||||
setSessionsLoading(false);
|
||||
setMessagesLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
loadSessions();
|
||||
}, [currentAssistantId, assistants]); // 依赖助手 ID 和助手列表 (以防助手信息更新)
|
||||
loadMessages();
|
||||
}, [currentSessionId]); // 依赖助手 ID 和助手列表 (以防助手信息更新)
|
||||
|
||||
// Auto scroll
|
||||
useEffect(() => {
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' });
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||
}, [messages]);
|
||||
|
||||
// Filter sessions for the current assistant (UI display)
|
||||
const currentAssistantSessions = React.useMemo(() => {
|
||||
// 直接从 allSessions 过滤,因为加载时已经更新了
|
||||
return allSessions
|
||||
.filter(s => s.assistant_id === currentAssistantId)
|
||||
.sort((a, b) => new Date(b.created_at).getTime() - new Date(a.created_at).getTime()); // 按时间倒序
|
||||
.filter((s) => s.assistant_id === currentAssistantId)
|
||||
.sort(
|
||||
(a, b) =>
|
||||
new Date(b.created_at).getTime() - new Date(a.created_at).getTime()
|
||||
); // 按时间倒序
|
||||
}, [allSessions, currentAssistantId]);
|
||||
|
||||
// --- Assistant CRUD Handlers (Updated with API calls) ---
|
||||
@ -349,18 +440,14 @@ export default function ChatPage() {
|
||||
if (id) {
|
||||
// 编辑
|
||||
savedAssistant = await updateAssistant(id, data);
|
||||
setAssistants(prev => prev.map(a => (a.id === id ? savedAssistant : a)));
|
||||
setAssistants((prev) =>
|
||||
prev.map((a) => (a.id === id ? savedAssistant : a))
|
||||
);
|
||||
toast.success(`助手 "${savedAssistant.name}" 已更新`);
|
||||
// 如果更新的是当前助手,可能需要重新加载会话或消息
|
||||
if (id === currentAssistantId) {
|
||||
// 简单处理:可以强制刷新会话列表(或提示用户)
|
||||
setCurrentAssistantId(null); // 触发 useEffect 重新加载
|
||||
setTimeout(() => setCurrentAssistantId(id), 0);
|
||||
}
|
||||
} else {
|
||||
// 创建
|
||||
savedAssistant = await createAssistant(data);
|
||||
setAssistants(prev => [...prev, savedAssistant]);
|
||||
setAssistants((prev) => [...prev, savedAssistant]);
|
||||
toast.success(`助手 "${savedAssistant.name}" 已创建`);
|
||||
// 创建后自动选中
|
||||
handleSelectAssistant(savedAssistant.id);
|
||||
@ -372,7 +459,7 @@ export default function ChatPage() {
|
||||
};
|
||||
|
||||
const handleDeleteAssistant = async (idToDelete: string) => {
|
||||
if (idToDelete === 'asst-default' || assistants.length <= 1) {
|
||||
if (idToDelete === "asst-default" || assistants.length <= 1) {
|
||||
toast.error("不能删除默认助手或最后一个助手");
|
||||
return;
|
||||
}
|
||||
@ -381,12 +468,16 @@ export default function ChatPage() {
|
||||
return;
|
||||
}
|
||||
|
||||
const assistantToDelete = assistants.find(a => a.id === idToDelete);
|
||||
if (window.confirm(`确定要删除助手 "${assistantToDelete?.name}" 吗?相关会话也将被删除。`)) {
|
||||
const assistantToDelete = assistants.find((a) => a.id === idToDelete);
|
||||
if (
|
||||
window.confirm(
|
||||
`确定要删除助手 "${assistantToDelete?.name}" 吗?相关会话也将被删除。`
|
||||
)
|
||||
) {
|
||||
try {
|
||||
await deleteAssistant(idToDelete);
|
||||
// 后端应负责删除关联会话,前端只需更新助手列表
|
||||
setAssistants(prev => prev.filter(a => a.id !== idToDelete));
|
||||
setAssistants((prev) => prev.filter((a) => a.id !== idToDelete));
|
||||
// (可选) 如果需要立即清除前端的会话缓存
|
||||
// setAllSessions(prev => prev.filter(s => s.assistant_id !== idToDelete));
|
||||
toast.success(`助手 "${assistantToDelete?.name}" 已删除`);
|
||||
@ -406,77 +497,115 @@ export default function ChatPage() {
|
||||
setIsAssistantDialogOpen(true);
|
||||
};
|
||||
|
||||
// --- Send Message Handler (Updated with API response handling) ---
|
||||
// --- Send Message Handler (Updated - handles new session ID from response) ---
|
||||
const handleSendMessage = async (e?: React.FormEvent<HTMLFormElement>) => {
|
||||
e?.preventDefault();
|
||||
const trimmedMessage = inputMessage.trim();
|
||||
if (!trimmedMessage || isLoading || !currentSessionId || !currentAssistantId) return; // 增加检查
|
||||
if (
|
||||
!trimmedMessage ||
|
||||
isLoading ||
|
||||
!currentSessionId ||
|
||||
!currentAssistantId
|
||||
)
|
||||
return;
|
||||
setError(null);
|
||||
setIsLoading(true);
|
||||
setIsLoading(true); // Start loading (for AI reply)
|
||||
|
||||
const userMessage: Message = {
|
||||
id: Date.now().toString(),
|
||||
const tempUserMessageId = `temp-user-${Date.now()}`; // Temporary ID for optimistic update
|
||||
const userMessageOptimistic: Message = {
|
||||
id: tempUserMessageId,
|
||||
session_id:
|
||||
currentSessionId === "temp-new-chat" ? "temp" : currentSessionId, // Use temp session id if needed
|
||||
text: trimmedMessage,
|
||||
sender: 'user',
|
||||
sender: "user",
|
||||
order: (messages[messages.length - 1]?.order || 0) + 1, // Estimate order
|
||||
created_at: new Date().toISOString(),
|
||||
};
|
||||
// 立即显示用户消息
|
||||
setMessages(prev => [...prev, userMessage]);
|
||||
setInputMessage(''); // 清空输入框
|
||||
|
||||
// Optimistic UI update: Add user message immediately
|
||||
setMessages((prev) => [...prev, userMessageOptimistic]);
|
||||
setInputMessage("");
|
||||
|
||||
let targetSessionId = currentSessionId; // Will be updated if new session is created
|
||||
|
||||
try {
|
||||
// 调用后端 API
|
||||
const response: ChatApiResponse = await sendChatMessage(
|
||||
trimmedMessage,
|
||||
currentSessionId, // 发送当前 session ID ('temp-new-chat' 或真实 ID)
|
||||
currentSessionId, // Send 'temp-new-chat' or actual ID
|
||||
currentAssistantId
|
||||
);
|
||||
|
||||
// 处理 AI 回复
|
||||
// Process successful response
|
||||
const aiMessage: Message = {
|
||||
id: Date.now().toString() + '_ai',
|
||||
id: `ai-${Date.now()}`, // Use temporary or actual ID from backend if provided
|
||||
session_id: response.session_id || targetSessionId, // Use new session ID if available
|
||||
text: response.reply,
|
||||
sender: 'ai',
|
||||
sender: "ai",
|
||||
order: userMessageOptimistic.order + 1, // Estimate order
|
||||
created_at: new Date().toISOString(),
|
||||
};
|
||||
setMessages((prevMessages) => [...prevMessages, aiMessage]);
|
||||
|
||||
// 如果后端创建了新会话并返回了信息
|
||||
if (response.session_id && response.session_title && currentSessionId === 'temp-new-chat') {
|
||||
// Update messages: Replace temp user message with potential real one (if backend returned it)
|
||||
// and add AI message. For simplicity, we just add the AI message.
|
||||
// A more robust solution would involve matching IDs.
|
||||
setMessages((prev) => [
|
||||
...prev.filter((m) => m.id !== tempUserMessageId),
|
||||
userMessageOptimistic,
|
||||
aiMessage,
|
||||
]); // Keep optimistic user msg for now
|
||||
|
||||
// If a new session was created by the backend
|
||||
if (
|
||||
response.session_id &&
|
||||
response.session_title &&
|
||||
currentSessionId === "temp-new-chat"
|
||||
) {
|
||||
const newSession: Session = {
|
||||
id: response.session_id,
|
||||
title: response.session_title,
|
||||
assistant_id: currentAssistantId,
|
||||
created_at: new Date().toISOString(), // 使用客户端时间或后端返回的时间
|
||||
created_at: new Date().toISOString(), // Or use time from backend if available
|
||||
};
|
||||
// 更新全局会话列表和当前会话 ID
|
||||
setAllSessions(prev => [...prev, newSession]);
|
||||
setCurrentSessionId(newSession.id);
|
||||
console.log(`前端已更新新会话信息: ID=${newSession.id}, Title=${newSession.title}`);
|
||||
setAllSessions((prev) => [...prev, newSession]);
|
||||
setCurrentSessionId(newSession.id); // Switch to the new session ID
|
||||
// Update the session_id of the messages just added
|
||||
setMessages((prev) =>
|
||||
prev.map((m) =>
|
||||
m.session_id === "temp" ? { ...m, session_id: newSession.id } : m
|
||||
)
|
||||
);
|
||||
console.log(
|
||||
`前端已更新新会话信息: ID=${newSession.id}, Title=${newSession.title}`
|
||||
);
|
||||
}
|
||||
|
||||
} catch (apiError: any) {
|
||||
console.error("发送消息失败:", apiError);
|
||||
const errorMessageText = apiError.message || '发生未知错误';
|
||||
setError(errorMessageText);
|
||||
// Handle error: Remove optimistic user message and show error
|
||||
setMessages((prev) => prev.filter((m) => m.id !== tempUserMessageId));
|
||||
const errorMessageText = apiError.message || "发生未知错误";
|
||||
toast.error(`发送消息失败: ${errorMessageText}`);
|
||||
setError(`发送消息失败: ${errorMessageText}`);
|
||||
// Optionally add an error message to the chat
|
||||
const errorMessage: Message = {
|
||||
id: Date.now().toString() + '_err',
|
||||
/* ... */ id: `err-${Date.now()}`,
|
||||
session_id: targetSessionId,
|
||||
text: `错误: ${errorMessageText}`,
|
||||
sender: 'ai',
|
||||
isError: true,
|
||||
sender: "ai",
|
||||
order: userMessageOptimistic.order + 1,
|
||||
created_at: new Date().toISOString(),
|
||||
};
|
||||
setMessages((prevMessages) => [...prevMessages, errorMessage]);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
setIsLoading(false); // Stop AI reply loading
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// --- Other Handlers (基本不变, 但需要检查 currentAssistantId/currentSessionId 是否存在) ---
|
||||
const handleInputChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setInputMessage(e.target.value);
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey && !isLoading) {
|
||||
if (e.key === "Enter" && !e.shiftKey && !isLoading) {
|
||||
e.preventDefault();
|
||||
handleSendMessage();
|
||||
}
|
||||
@ -487,55 +616,80 @@ export default function ChatPage() {
|
||||
};
|
||||
|
||||
const handleSelectAssistant = (assistantId: string) => {
|
||||
console.log("选择助手id",assistantId)
|
||||
if (assistantId !== currentAssistantId) {
|
||||
setCurrentAssistantId(assistantId); // 触发 useEffect 加载会话
|
||||
console.log("当前助手id",currentAssistantId)
|
||||
}
|
||||
}
|
||||
};
|
||||
// Updated handleSelectSession to just set the ID, useEffect handles loading
|
||||
|
||||
const handleSelectSession = (sessionId: string) => {
|
||||
if (sessionId !== currentSessionId) {
|
||||
setCurrentSessionId(sessionId);
|
||||
// TODO: 调用 API 加载该会话的历史消息
|
||||
console.log(`切换到会话: ${sessionId}`);
|
||||
const session = allSessions.find(s => s.id === sessionId);
|
||||
const assistant = assistants.find(a => a.id === session?.assistant_id);
|
||||
setMessages([
|
||||
{ id: `init-${sessionId}-1`, text: `继续与 ${assistant?.name || '助手'} 的对话: ${session?.title || ''}`, sender: 'ai' },
|
||||
// ... 加载真实历史消息
|
||||
]);
|
||||
}
|
||||
setCurrentSessionId(sessionId); // Trigger useEffect to load messages
|
||||
}
|
||||
};
|
||||
|
||||
const handleNewTopic = () => {
|
||||
if (currentSessionId !== 'temp-new-chat' && currentAssistantId) { // 确保有助手被选中
|
||||
setCurrentSessionId('temp-new-chat');
|
||||
const currentAssistant = assistants.find(a => a.id === currentAssistantId);
|
||||
setMessages([
|
||||
{ id: `init-temp-${currentAssistantId}`, text: `开始与 ${currentAssistant?.name || '助手'} 的新对话吧!`, sender: 'ai' },
|
||||
]);
|
||||
if (currentSessionId !== "temp-new-chat" && currentAssistantId) {
|
||||
// 确保有助手被选中
|
||||
setCurrentSessionId("temp-new-chat");
|
||||
console.log("手动创建临时新对话");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// --- JSX Rendering ---
|
||||
return (
|
||||
// 最外层 Flex 容器
|
||||
<div className="flex h-full gap-1"> {/* 使用 gap 添加间距 */}
|
||||
<div className="flex h-full gap-1">
|
||||
{" "}
|
||||
{/* 使用 gap 添加间距 */}
|
||||
<Toaster position="top-center" richColors />
|
||||
{/* 左侧助手面板 */}
|
||||
<aside className="w-64 bg-white dark:bg-gray-800 rounded-lg shadow-md p-4 flex flex-col">
|
||||
<h2 className="w-full text-lg font-semibold mb-4 text-gray-800 dark:text-gray-200 flex items-center justify-between">
|
||||
<span>助手列表</span>
|
||||
{/* 添加刷新按钮 */}
|
||||
<Button variant="ghost" size="icon" className="h-7 w-7" onClick={() => {
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-7 w-7"
|
||||
onClick={() => {
|
||||
// 重新加载助手列表
|
||||
const loadAssistants = async () => { /* ... */ }; // 将加载逻辑提取出来
|
||||
const loadAssistants = async () => {
|
||||
setAssistantsLoading(true);
|
||||
try {
|
||||
const fetchedAssistants = await getAssistants();
|
||||
setAssistants(fetchedAssistants);
|
||||
const defaultAssistant =
|
||||
fetchedAssistants.find((a) => a.id === "asst-default") ||
|
||||
fetchedAssistants[0];
|
||||
if (defaultAssistant) {
|
||||
setCurrentAssistantId(defaultAssistant.id);
|
||||
} else {
|
||||
console.warn("No default or initial assistant found.");
|
||||
}
|
||||
} catch (apiError: any) {
|
||||
toast.error(`加载助手列表失败: ${apiError.message}`);
|
||||
setError(`无法加载助手: ${apiError.message}`);
|
||||
} finally {
|
||||
setAssistantsLoading(false);
|
||||
}
|
||||
}; // 将加载逻辑提取出来
|
||||
loadAssistants();
|
||||
}} disabled={assistantsLoading}>
|
||||
<RefreshCw size={16} className={assistantsLoading ? 'animate-spin' : ''}/>
|
||||
}}
|
||||
disabled={assistantsLoading}
|
||||
>
|
||||
<RefreshCw
|
||||
size={16}
|
||||
className={assistantsLoading ? "animate-spin" : ""}
|
||||
/>
|
||||
</Button>
|
||||
</h2>
|
||||
<Dialog open={isAssistantDialogOpen} onOpenChange={setIsAssistantDialogOpen}>
|
||||
<Dialog
|
||||
open={isAssistantDialogOpen}
|
||||
onOpenChange={setIsAssistantDialogOpen}
|
||||
>
|
||||
<DialogTrigger asChild>
|
||||
<Button
|
||||
variant="default" // 使用 shadcn Button
|
||||
@ -547,28 +701,39 @@ export default function ChatPage() {
|
||||
</Button>
|
||||
</DialogTrigger>
|
||||
{/* Dialog 内容 */}
|
||||
<DialogContent className="sm:max-w-[600px]"> {/* 调整宽度 */}
|
||||
<DialogContent className="sm:max-w-[600px]">
|
||||
{" "}
|
||||
{/* 调整宽度 */}
|
||||
<DialogHeader>
|
||||
<DialogTitle>{editingAssistant ? '编辑助手' : '创建新助手'}</DialogTitle>
|
||||
<DialogTitle>
|
||||
{editingAssistant ? "编辑助手" : "创建新助手"}
|
||||
</DialogTitle>
|
||||
<DialogDescription>
|
||||
{editingAssistant ? '修改助手的配置信息。' : '定义一个新助手的名称、行为和参数。'}
|
||||
{editingAssistant
|
||||
? "修改助手的配置信息。"
|
||||
: "定义一个新助手的名称、行为和参数。"}
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
{/* 渲染助手表单 */}
|
||||
<AssistantForm
|
||||
key={editingAssistant?.id || 'create'} // 添加 key 确保编辑时表单重置
|
||||
key={editingAssistant?.id || "create"} // 添加 key 确保编辑时表单重置
|
||||
assistant={editingAssistant}
|
||||
onSave={handleSaveAssistant}
|
||||
onClose={() => setIsAssistantDialogOpen(false)} // 传递关闭回调
|
||||
/>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
<div className="flex-1 overflow-y-auto space-y-2 pr-1"> {/* 添加右内边距防止滚动条遮挡 */}
|
||||
<div className="flex-1 overflow-y-auto space-y-2 pr-1">
|
||||
{" "}
|
||||
{/* 添加右内边距防止滚动条遮挡 */}
|
||||
{/* 渲染助手列表 */}
|
||||
{assistantsLoading ? (
|
||||
// 显示骨架屏
|
||||
Array.from({ length: 3 }).map((_, index) => (
|
||||
<div key={index} className="p-3 rounded-lg flex items-center gap-3">
|
||||
<div
|
||||
key={index}
|
||||
className="p-3 rounded-lg flex items-center gap-3"
|
||||
>
|
||||
<Skeleton className="h-8 w-8 rounded-full" />
|
||||
<div className="flex-1 space-y-1">
|
||||
<Skeleton className="h-4 w-3/4" />
|
||||
@ -577,24 +742,34 @@ export default function ChatPage() {
|
||||
</div>
|
||||
))
|
||||
) : assistants.length === 0 ? (
|
||||
<p className="text-center text-sm text-gray-500 dark:text-gray-400 mt-4">没有找到助手。</p>
|
||||
<p className="text-center text-sm text-gray-500 dark:text-gray-400 mt-4">
|
||||
没有找到助手。
|
||||
</p>
|
||||
) : (
|
||||
// 渲染助手列表
|
||||
assistants.map(assistant => (
|
||||
assistants.map((assistant) => (
|
||||
<div
|
||||
key={assistant.id}
|
||||
onClick={() => handleSelectAssistant(assistant.id)}
|
||||
className={`group p-2 rounded-lg cursor-pointer flex items-center gap-3 relative ${
|
||||
currentAssistantId === assistant.id
|
||||
? 'bg-red-100 dark:bg-red-900/50 ring-1 ring-red-300 dark:ring-red-700'
|
||||
: 'text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700'
|
||||
? "bg-red-100 dark:bg-red-900/50 ring-1 ring-red-300 dark:ring-red-700"
|
||||
: "text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||
}`}
|
||||
title={assistant.description || ''}
|
||||
title={assistant.description || ""}
|
||||
>
|
||||
{/* ... 助手头像和名称 ... */}
|
||||
<span className="text-lg flex-shrink-0 w-6 text-center">{assistant.avatar || '👤'}</span>
|
||||
<span className="text-lg flex-shrink-0 w-6 text-center">
|
||||
{assistant.avatar || "👤"}
|
||||
</span>
|
||||
<div className="flex-1 overflow-hidden">
|
||||
<p className={`text-sm font-medium truncate ${currentAssistantId === assistant.id ? 'text-red-800 dark:text-red-200' : 'text-gray-800 dark:text-gray-200'}`}>
|
||||
<p
|
||||
className={`text-sm font-medium truncate ${
|
||||
currentAssistantId === assistant.id
|
||||
? "text-red-800 dark:text-red-200"
|
||||
: "text-gray-800 dark:text-gray-200"
|
||||
}`}
|
||||
>
|
||||
{assistant.name}
|
||||
</p>
|
||||
</div>
|
||||
@ -604,17 +779,23 @@ export default function ChatPage() {
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-6 w-6 text-blue-600 hover:text-blue-800 dark:text-blue-400 dark:hover:text-blue-300"
|
||||
onClick={(e) => { e.stopPropagation(); handleEditAssistant(assistant); }} // 阻止事件冒泡
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleEditAssistant(assistant);
|
||||
}} // 阻止事件冒泡
|
||||
title="编辑助手"
|
||||
>
|
||||
<Edit size={14} />
|
||||
</Button>
|
||||
{assistant.id !== 'asst-default' && ( // 不显示默认助手的删除按钮
|
||||
{assistant.id !== "asst-default" && ( // 不显示默认助手的删除按钮
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="h-6 w-6 text-red-500 hover:text-red-700 dark:text-red-400 dark:hover:text-red-300"
|
||||
onClick={(e) => { e.stopPropagation(); handleDeleteAssistant(assistant.id); }}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleDeleteAssistant(assistant.id);
|
||||
}}
|
||||
title="删除助手"
|
||||
>
|
||||
<Trash2 size={14} />
|
||||
@ -626,18 +807,26 @@ export default function ChatPage() {
|
||||
)}
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
{/* 中间主聊天区域 */}
|
||||
<div className="flex flex-col flex-1 bg-white dark:bg-gray-800 rounded-lg shadow-md overflow-hidden">
|
||||
{/* 聊天窗口标题 - 显示当前助手和切换会话按钮 */}
|
||||
<div className="flex justify-between items-center p-4 border-b dark:border-gray-700">
|
||||
{currentAssistantId ? (
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-xl">{assistants.find(a => a.id === currentAssistantId)?.avatar || '👤'}</span>
|
||||
<span className="text-xl">
|
||||
{assistants.find((a) => a.id === currentAssistantId)?.avatar ||
|
||||
"👤"}
|
||||
</span>
|
||||
<h1 className="text-lg font-semibold text-gray-800 dark:text-gray-200">
|
||||
{assistants.find(a => a.id === currentAssistantId)?.name || '加载中...'}
|
||||
{assistants.find((a) => a.id === currentAssistantId)?.name ||
|
||||
"加载中..."}
|
||||
<span className="text-sm font-normal text-gray-500 dark:text-gray-400 ml-2">
|
||||
({currentSessionId === 'temp-new-chat' ? '新话题' : allSessions.find(s => s.id === currentSessionId)?.title || (sessionsLoading ? '加载中...' : '选择话题')})
|
||||
(
|
||||
{currentSessionId === "temp-new-chat"
|
||||
? "新话题"
|
||||
: allSessions.find((s) => s.id === currentSessionId)
|
||||
?.title || (sessionsLoading ? "加载中..." : "选择话题")}
|
||||
)
|
||||
</span>
|
||||
</h1>
|
||||
</div>
|
||||
@ -649,41 +838,64 @@ export default function ChatPage() {
|
||||
className="p-1 rounded text-gray-500 dark:text-gray-400 hover:bg-gray-200 dark:hover:bg-gray-700"
|
||||
title={isSessionPanelOpen ? "关闭会话面板" : "打开会话面板"}
|
||||
>
|
||||
{isSessionPanelOpen ? <PanelRightClose size={20} /> : <PanelRightOpen size={20} />}
|
||||
{isSessionPanelOpen ? (
|
||||
<PanelRightClose size={20} />
|
||||
) : (
|
||||
<PanelRightOpen size={20} />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* 消息显示区域 */}
|
||||
<div className="flex-1 overflow-y-auto p-4 space-y-4">
|
||||
{/* 可以添加一个全局错误提示 */}
|
||||
{error && <p className="text-center text-sm text-red-500 dark:text-red-400">{error}</p>}
|
||||
{messages.length === 0 && !isLoading && !sessionsLoading && currentSessionId !== 'temp-new-chat' && (
|
||||
<p className="text-center text-sm text-gray-500 dark:text-gray-400 mt-8">选择一个话题开始聊天,或新建一个话题。</p>
|
||||
{error && (
|
||||
<p className="text-center text-sm text-red-500 dark:text-red-400">
|
||||
{error}
|
||||
</p>
|
||||
)}
|
||||
{messages.map((message) => (
|
||||
{messagesLoading ? (
|
||||
// Message loading skeleton
|
||||
<div className="space-y-4">
|
||||
<Skeleton className="h-10 w-3/5 rounded-lg" />
|
||||
<Skeleton className="h-12 w-4/5 ml-auto rounded-lg" />
|
||||
<Skeleton className="h-8 w-1/2 rounded-lg" />
|
||||
</div>
|
||||
) : messages.length === 0 && currentSessionId !== "temp-new-chat" ? (
|
||||
<p className="text-center text-sm text-gray-500 dark:text-gray-400 mt-8">
|
||||
{currentAssistantId
|
||||
? "选择或新建一个话题开始聊天。"
|
||||
: "请先选择一个助手。"}
|
||||
</p>
|
||||
) : (
|
||||
// Render actual messages
|
||||
messages.map((message) => (
|
||||
<div
|
||||
key={message.id}
|
||||
key={message.id} // Use message ID from DB
|
||||
className={`flex ${
|
||||
message.sender === 'user' ? 'justify-end' : 'justify-start'
|
||||
message.sender === "user" ? "justify-end" : "justify-start"
|
||||
}`}
|
||||
>
|
||||
<div
|
||||
className={`max-w-xs md:max-w-md lg:max-w-lg px-4 py-2 rounded-lg shadow ${
|
||||
message.sender === 'user'
|
||||
? 'bg-red-500 text-white'
|
||||
: message.isError
|
||||
? 'bg-red-100 dark:bg-red-900/50 text-red-700 dark:text-red-300'
|
||||
: 'bg-gray-200 dark:bg-gray-700 text-gray-800 dark:text-gray-200'
|
||||
message.sender === "user"
|
||||
? "bg-red-500 text-white"
|
||||
: message.isError // Check if it's an error message added by frontend
|
||||
? "bg-red-100 dark:bg-red-900/50 text-red-700 dark:text-red-300"
|
||||
: "bg-gray-200 dark:bg-gray-700 text-gray-800 dark:text-gray-200"
|
||||
}`}
|
||||
>
|
||||
<p className="text-sm whitespace-pre-wrap">{message.text}</p>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
))
|
||||
)}
|
||||
{isLoading && (
|
||||
<div className="flex justify-center items-center p-2">
|
||||
<Loader2 className="h-5 w-5 animate-spin text-gray-500 dark:text-gray-400" />
|
||||
<span className="ml-2 text-sm text-gray-500 dark:text-gray-400">AI 正在思考...</span>
|
||||
<span className="ml-2 text-sm text-gray-500 dark:text-gray-400">
|
||||
AI 正在思考...
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
<div ref={messagesEndRef} />
|
||||
@ -691,36 +903,70 @@ export default function ChatPage() {
|
||||
|
||||
{/* 消息输入区域 */}
|
||||
<div className="p-4 border-t dark:border-gray-700">
|
||||
<form onSubmit={handleSendMessage} className="flex items-center space-x-2">
|
||||
<form
|
||||
onSubmit={handleSendMessage}
|
||||
className="flex items-center space-x-2"
|
||||
>
|
||||
<input
|
||||
type="text"
|
||||
value={inputMessage}
|
||||
onChange={handleInputChange}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={isLoading ? "AI 正在回复..." : "输入你的消息..."}
|
||||
disabled={isLoading || sessionsLoading || !currentAssistantId || !currentSessionId} // 添加禁用条件
|
||||
disabled={
|
||||
isLoading ||
|
||||
messagesLoading ||
|
||||
sessionsLoading ||
|
||||
!currentAssistantId ||
|
||||
!currentSessionId
|
||||
} // 添加禁用条件
|
||||
className="flex-1 px-4 py-2 border rounded-lg focus:outline-none focus:ring-2 focus:ring-red-500 dark:bg-gray-700 dark:border-gray-600 dark:text-gray-200 dark:focus:ring-red-600 disabled:opacity-70 transition-opacity"
|
||||
aria-label="聊天输入框"
|
||||
/>
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={!inputMessage.trim() || isLoading || sessionsLoading || !currentAssistantId || !currentSessionId} // 添加禁用条件
|
||||
disabled={
|
||||
!inputMessage.trim() ||
|
||||
isLoading ||
|
||||
messagesLoading ||
|
||||
sessionsLoading ||
|
||||
!currentAssistantId ||
|
||||
!currentSessionId
|
||||
} // 添加禁用条件
|
||||
className="p-2 rounded-lg bg-red-500 text-white hover:bg-red-600 focus:outline-none focus:ring-2 focus:ring-red-500 focus:ring-offset-2 dark:focus:ring-offset-gray-800 disabled:opacity-50 disabled:cursor-not-allowed transition-all flex items-center justify-center h-10 w-10"
|
||||
aria-label={isLoading ? "正在发送" : "发送消息"}
|
||||
>
|
||||
{isLoading ? <Loader2 className="h-5 w-5 animate-spin" /> : <SendHorizontal size={20} />}
|
||||
{isLoading ? (
|
||||
<Loader2 className="h-5 w-5 animate-spin" />
|
||||
) : (
|
||||
<SendHorizontal size={20} />
|
||||
)}
|
||||
</Button>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 右侧会话管理面板 */}
|
||||
<aside className={`bg-white dark:bg-gray-800 rounded-lg shadow-md p-2 flex flex-col transition-all duration-300 ease-in-out ${isSessionPanelOpen ? 'w-64' : 'w-0 p-0 border-0 overflow-hidden opacity-0'}`}> {/* 调整关闭时的样式 */}
|
||||
<h2 className="text-lg font-semibold mb-4 text-gray-800 dark:text-gray-200 whitespace-nowrap items-center justify-center">话题列表</h2> {/* 改为话题 */}
|
||||
<aside
|
||||
className={`bg-white dark:bg-gray-800 rounded-lg shadow-md p-2 flex flex-col transition-all duration-300 ease-in-out ${
|
||||
isSessionPanelOpen
|
||||
? "w-64"
|
||||
: "w-0 p-0 border-0 overflow-hidden opacity-0"
|
||||
}`}
|
||||
>
|
||||
{" "}
|
||||
{/* 调整关闭时的样式 */}
|
||||
<h2 className="text-lg font-semibold mb-4 text-gray-800 dark:text-gray-200 whitespace-nowrap items-center justify-center">
|
||||
话题列表
|
||||
</h2>{" "}
|
||||
{/* 改为话题 */}
|
||||
<Button
|
||||
onClick={handleNewTopic} // 绑定新建话题事件
|
||||
className="mb-4 w-full px-3 py-2 bg-red-500 text-white rounded-lg hover:bg-red-600 transition-colors text-sm whitespace-nowrap flex items-center justify-center gap-2 disabled:opacity-50"
|
||||
disabled={currentSessionId === 'temp-new-chat' || sessionsLoading || !currentAssistantId} // 添加禁用条件
|
||||
disabled={
|
||||
currentSessionId === "temp-new-chat" ||
|
||||
sessionsLoading ||
|
||||
!currentAssistantId
|
||||
} // 添加禁用条件
|
||||
>
|
||||
+ 新建话题
|
||||
</Button>
|
||||
@ -730,19 +976,22 @@ export default function ChatPage() {
|
||||
Array.from({ length: 5 }).map((_, index) => (
|
||||
<Skeleton key={index} className="h-8 w-full my-1.5 rounded-lg" />
|
||||
))
|
||||
) : currentAssistantSessions.length === 0 && currentSessionId !== 'temp-new-chat' ? (
|
||||
<p className="text-center text-sm text-gray-500 dark:text-gray-400 mt-4">没有历史话题。</p>
|
||||
) : currentAssistantSessions.length === 0 &&
|
||||
currentSessionId !== "temp-new-chat" ? (
|
||||
<p className="text-center text-sm text-gray-500 dark:text-gray-400 mt-4">
|
||||
没有历史话题。
|
||||
</p>
|
||||
) : (
|
||||
<>
|
||||
{/* 渲染会话列表 */}
|
||||
{currentAssistantSessions.map(session => (
|
||||
{currentAssistantSessions.map((session) => (
|
||||
<div
|
||||
key={session.id}
|
||||
onClick={() => handleSelectSession(session.id)}
|
||||
className={`p-2 rounded-lg cursor-pointer text-sm truncate whitespace-nowrap ${
|
||||
currentSessionId === session.id
|
||||
? 'bg-red-100 dark:bg-red-900/50 text-red-700 dark:text-red-400 font-medium'
|
||||
: 'text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700'
|
||||
? "bg-red-100 dark:bg-red-900/50 text-red-700 dark:text-red-400 font-medium"
|
||||
: "text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||
}`}
|
||||
title={session.title}
|
||||
>
|
||||
@ -751,7 +1000,7 @@ export default function ChatPage() {
|
||||
</div>
|
||||
))}
|
||||
{/* 新话题占位符 */}
|
||||
{currentSessionId === 'temp-new-chat' && (
|
||||
{currentSessionId === "temp-new-chat" && (
|
||||
<div className="p-2 rounded-lg text-sm truncate whitespace-nowrap bg-red-100 dark:bg-red-900/50 text-red-700 dark:text-red-400 font-medium">
|
||||
新话题...
|
||||
</div>
|
||||
@ -760,7 +1009,6 @@ export default function ChatPage() {
|
||||
)}
|
||||
</div>
|
||||
</aside>
|
||||
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@ -45,7 +45,7 @@ export default function RootLayout({
|
||||
<html lang="zh-CN">
|
||||
<body className={`${geistSans.variable} ${geistMono.variable} antialiased flex h-screen bg-gray-100 dark:bg-gray-900`}>
|
||||
{/* 侧边栏导航 */}
|
||||
<aside className="w-16 bg-white dark:bg-gray-800 p-3 shadow-md flex flex-col items-center mr-1"> {/* 调整内边距和对齐 */}
|
||||
<aside className="w-16 bg-white dark:bg-gray-800 p-3 shadow-md flex flex-col items-center rounded-lg"> {/* 调整内边距和对齐 */}
|
||||
{/* Logo */}
|
||||
<div className="mb-6"> {/* 调整 Logo 边距 */}
|
||||
<Link href="/" className="flex items-center justify-center text-3xl font-bold text-red-600 dark:text-red-500" title="CherryAI 主页">
|
||||
@ -60,12 +60,11 @@ export default function RootLayout({
|
||||
<Link
|
||||
href={item.href}
|
||||
className="relative flex items-center justify-center p-2 rounded-lg text-gray-600 dark:text-gray-400 hover:bg-red-100 dark:hover:bg-red-900/50 hover:text-red-700 dark:hover:text-red-400 transition-colors duration-200 group" // 居中图标
|
||||
title={item.name} // 保留原生 title
|
||||
>
|
||||
<item.icon className="h-6 w-6 flex-shrink-0 text-gray-600 dark:text-gray-400" />
|
||||
{/* Tooltip 文字标签 */}
|
||||
<span
|
||||
className="absolute left-full top-1/2 -translate-y-1/2 ml-3 px-2 py-1 bg-gray-900 dark:bg-gray-700 text-white text-xl rounded shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-200 delay-150 whitespace-nowrap pointer-events-none" // 使用 pointer-events-none 避免干扰悬浮
|
||||
className="z-10 absolute left-full top-1/2 -translate-y-1/2 ml-3 px-2 py-1 bg-gray-900 dark:bg-gray-700 text-white text-xl rounded shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-200 delay-150 whitespace-nowrap pointer-events-none" // 使用 pointer-events-none 避免干扰悬浮
|
||||
>
|
||||
{item.name}
|
||||
</span>
|
||||
|
||||
@ -1,40 +1,28 @@
|
||||
// File: frontend/lib/api.ts (更新)
|
||||
// Description: 添加调用助手和会话管理 API 的函数
|
||||
|
||||
import { Assistant, AssistantCreateData, AssistantUpdateData } from "@/types/assistant";
|
||||
import axios from "axios";
|
||||
|
||||
// --- Types (从后端模型同步或手动定义) ---
|
||||
// 这些类型应该与后端 pydantic_models.py 中的 Read 模型匹配
|
||||
export interface Assistant {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string | null;
|
||||
avatar?: string | null;
|
||||
system_prompt: string;
|
||||
model: string;
|
||||
temperature: number;
|
||||
}
|
||||
|
||||
// --- Types ---
|
||||
export interface Session {
|
||||
id: string;
|
||||
title: string;
|
||||
assistant_id: string;
|
||||
created_at: string; // ISO date string
|
||||
updated_at?: string | null; // Add updated_at
|
||||
}
|
||||
|
||||
// 创建助手时发送的数据类型
|
||||
export interface AssistantCreateData {
|
||||
name: string;
|
||||
description?: string | null;
|
||||
avatar?: string | null;
|
||||
system_prompt: string;
|
||||
model: string;
|
||||
temperature: number;
|
||||
// Message type from backend
|
||||
export interface Message {
|
||||
id: string;
|
||||
session_id: string;
|
||||
sender: 'user' | 'ai'; // Or extend with 'system' if needed
|
||||
text: string;
|
||||
order: number;
|
||||
created_at: string; // ISO date string
|
||||
}
|
||||
|
||||
// 更新助手时发送的数据类型 (所有字段可选)
|
||||
export type AssistantUpdateData = Partial<AssistantCreateData>;
|
||||
|
||||
// 聊天响应类型
|
||||
export interface ChatApiResponse {
|
||||
reply: string;
|
||||
@ -169,4 +157,19 @@ export const deleteSession = async (sessionId: string): Promise<void> => {
|
||||
//当前端发送 sessionId 为 'temp-new-chat' 的消息时,后端会自动创建。
|
||||
//如果需要单独创建会话(例如,不发送消息就创建),则需要单独实现前端调用 POST /sessions/。
|
||||
|
||||
// TODO: 添加获取会话消息的 API 函数 (GET /sessions/{session_id}/messages)
|
||||
// --- Message API (New) ---
|
||||
/** 获取指定会话的消息列表 */
|
||||
export const getMessagesBySession = async (sessionId: string, limit: number = 100, skip: number = 0): Promise<Message[]> => {
|
||||
try {
|
||||
const response = await apiClient.get<Message[]>(`/messages/session/${sessionId}`, {
|
||||
params: { limit, skip }
|
||||
});
|
||||
return response.data;
|
||||
} catch (error) {
|
||||
// Handle 404 specifically if needed (session exists but no messages)
|
||||
if (axios.isAxiosError(error) && error.response?.status === 404) {
|
||||
return []; // Return empty list if session not found or no messages
|
||||
}
|
||||
throw new Error(handleApiError(error, 'getMessagesBySession'));
|
||||
}
|
||||
};
|
||||
|
||||
23
frontend/types/assistant.ts
Normal file
23
frontend/types/assistant.ts
Normal file
@ -0,0 +1,23 @@
|
||||
// --- Types (从后端模型同步或手动定义) ---
|
||||
// 这些类型应该与后端 pydantic_models.py 中的 Read 模型匹配
|
||||
export interface Assistant {
|
||||
id: string;
|
||||
name: string;
|
||||
description?: string | null;
|
||||
avatar?: string | null;
|
||||
system_prompt: string;
|
||||
model: string;
|
||||
temperature: number;
|
||||
}
|
||||
// 创建助手时发送的数据类型
|
||||
export interface AssistantCreateData {
|
||||
name: string;
|
||||
description?: string | null;
|
||||
avatar?: string | null;
|
||||
system_prompt: string;
|
||||
model: string;
|
||||
temperature: number;
|
||||
}
|
||||
|
||||
// 更新助手时发送的数据类型 (所有字段可选)
|
||||
export type AssistantUpdateData = Partial<AssistantCreateData>;
|
||||
Loading…
x
Reference in New Issue
Block a user