添加数据管理

This commit is contained in:
adrian 2025-04-30 04:39:36 +08:00
parent 7df10e82be
commit f0863914c2
18 changed files with 1108 additions and 676 deletions

View File

@ -1,12 +1,12 @@
# File: backend/app/api/v1/api.py (更新) # File: backend/app/api/v1/api.py (Update)
# Description: 聚合 v1 版本的所有 API 路由 # Description: 聚合 v1 版本的所有 API 路由
from fastapi import APIRouter from fastapi import APIRouter
from app.api.v1.endpoints import chat, assistants, sessions # 导入新路由 from app.api.v1.endpoints import chat, assistants, sessions, messages # Import messages router
api_router = APIRouter() api_router = APIRouter()
api_router.include_router(chat.router, prefix="/chat", tags=["Chat"]) api_router.include_router(chat.router, prefix="/chat", tags=["Chat"])
api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"]) # 添加助手路由 api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"])
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"]) # 添加会话路由 api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
api_router.include_router(messages.router, prefix="/messages", tags=["Messages"]) # Add messages router

View File

@ -1,39 +1,39 @@
# File: backend/app/api/v1/endpoints/assistants.py (新建) # File: backend/app/api/v1/endpoints/assistants.py (Update with DB session dependency)
# Description: 助手的 API 路由 # Description: 助手的 API 路由 (使用数据库会话)
from fastapi import APIRouter, HTTPException, Depends, status from fastapi import APIRouter, HTTPException, Depends, status
from typing import List from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session # Import DB session dependency
from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate
from app.services.assistant_service import assistant_service_instance, AssistantService from app.services.assistant_service import AssistantService # Import the class
router = APIRouter() router = APIRouter()
# --- 依赖注入 AssistantService --- # --- Dependency Injection for Service and DB Session ---
def get_assistant_service() -> AssistantService: # Service instance can be created per request or globally
return assistant_service_instance # For simplicity, let's create it here, but pass db session to methods
assistant_service = AssistantService()
@router.post("/", response_model=AssistantRead, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=AssistantRead, status_code=status.HTTP_201_CREATED)
async def create_new_assistant( async def create_new_assistant(
assistant_data: AssistantCreate, assistant_data: AssistantCreate,
service: AssistantService = Depends(get_assistant_service) db: AsyncSession = Depends(get_db_session) # Inject DB session
): ):
"""创建新助手""" return await assistant_service.create_assistant(db, assistant_data)
return service.create_assistant(assistant_data)
@router.get("/", response_model=List[AssistantRead]) @router.get("/", response_model=List[AssistantRead])
async def read_all_assistants( async def read_all_assistants(
service: AssistantService = Depends(get_assistant_service) db: AsyncSession = Depends(get_db_session)
): ):
"""获取所有助手列表""" return await assistant_service.get_assistants(db)
return service.get_assistants()
@router.get("/{assistant_id}", response_model=AssistantRead) @router.get("/{assistant_id}", response_model=AssistantRead)
async def read_assistant_by_id( async def read_assistant_by_id(
assistant_id: str, assistant_id: str,
service: AssistantService = Depends(get_assistant_service) db: AsyncSession = Depends(get_db_session)
): ):
"""根据 ID 获取特定助手""" assistant = await assistant_service.get_assistant(db, assistant_id)
assistant = service.get_assistant(assistant_id)
if not assistant: if not assistant:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
return assistant return assistant
@ -42,10 +42,9 @@ async def read_assistant_by_id(
async def update_existing_assistant( async def update_existing_assistant(
assistant_id: str, assistant_id: str,
assistant_data: AssistantUpdate, assistant_data: AssistantUpdate,
service: AssistantService = Depends(get_assistant_service) db: AsyncSession = Depends(get_db_session)
): ):
"""更新指定 ID 的助手""" updated_assistant = await assistant_service.update_assistant(db, assistant_id, assistant_data)
updated_assistant = service.update_assistant(assistant_id, assistant_data)
if not updated_assistant: if not updated_assistant:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
return updated_assistant return updated_assistant
@ -53,14 +52,17 @@ async def update_existing_assistant(
@router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_existing_assistant( async def delete_existing_assistant(
assistant_id: str, assistant_id: str,
service: AssistantService = Depends(get_assistant_service) db: AsyncSession = Depends(get_db_session)
): ):
"""删除指定 ID 的助手""" # Handle potential error from service if trying to delete default
deleted = service.delete_assistant(assistant_id) try:
if not deleted: deleted = await assistant_service.delete_assistant(db, assistant_id)
# 根据服务层逻辑判断是找不到还是不允许删除 if not deleted:
assistant = service.get_assistant(assistant_id) # Check if it exists to differentiate 404 from 403 (or handle in service)
if assistant and assistant_id == 'asst-default': assistant = await assistant_service.get_assistant(db, assistant_id)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="不允许删除默认助手") if assistant and assistant_id == 'asst-default':
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="不允许删除默认助手")
# 成功删除,不返回内容 raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
except Exception as e: # Catch other potential DB errors
print(f"删除助手时出错: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除助手失败")

View File

@ -1,27 +1,26 @@
# File: backend/app/api/v1/endpoints/chat.py (更新) # File: backend/app/api/v1/endpoints/chat.py (Update with DB session dependency)
# Description: 聊天功能的 API 路由 (使用更新后的 ChatService) # Description: 聊天功能的 API 路由 (使用数据库会话)
from fastapi import APIRouter, HTTPException, Depends, status from fastapi import APIRouter, HTTPException, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session
from app.models.pydantic_models import ChatRequest, ChatResponse, SessionCreateRequest from app.models.pydantic_models import ChatRequest, ChatResponse, SessionCreateRequest
from app.services.chat_service import chat_service_instance, ChatService from app.services.chat_service import ChatService # Import class
from app.services.session_service import session_service_instance, SessionService # 导入 SessionService from app.services.session_service import SessionService # Import class
import app.core.config as Config # Import API Key for ChatService instantiation
router = APIRouter() router = APIRouter()
# --- 依赖注入 --- # --- Dependency Injection ---
def get_chat_service() -> ChatService: # Instantiate services here or use a more sophisticated dependency injection system
return chat_service_instance chat_service = ChatService(default_api_key=Config.GOOGLE_API_KEY)
session_service = SessionService()
def get_session_service() -> SessionService:
return session_service_instance
@router.post("/", response_model=ChatResponse) @router.post("/", response_model=ChatResponse)
async def handle_chat_message( async def handle_chat_message(
request: ChatRequest, request: ChatRequest,
chat_service: ChatService = Depends(get_chat_service), db: AsyncSession = Depends(get_db_session) # Inject DB session
session_service: SessionService = Depends(get_session_service) # 注入 SessionService
): ):
"""处理用户发送的聊天消息 (包含 assistantId 和 sessionId)"""
user_message = request.message user_message = request.message
session_id = request.session_id session_id = request.session_id
assistant_id = request.assistant_id assistant_id = request.assistant_id
@ -31,38 +30,39 @@ async def handle_chat_message(
response_session_id = None response_session_id = None
response_session_title = None response_session_title = None
# --- 处理临时新会话 ---
if session_id == 'temp-new-chat': if session_id == 'temp-new-chat':
print("检测到临时新会话,正在创建...") print("检测到临时新会话,正在创建...")
try: try:
# 调用 SessionService 创建会话
create_req = SessionCreateRequest(assistant_id=assistant_id, first_message=user_message) create_req = SessionCreateRequest(assistant_id=assistant_id, first_message=user_message)
created_session = await session_service.create_session(create_req) # Pass db session to the service method
session_id = created_session.id # 使用新创建的会话 ID created_session = await session_service.create_session(db, create_req)
response_session_id = created_session.id # 准备在响应中返回新 ID session_id = created_session.id
response_session_title = created_session.title # 准备在响应中返回新标题 response_session_id = created_session.id
response_session_title = created_session.title
print(f"新会话已创建: ID={session_id}, Title='{created_session.title}'") print(f"新会话已创建: ID={session_id}, Title='{created_session.title}'")
except ValueError as e: # 助手不存在等错误 except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e: # LLM 调用或其他错误 except Exception as e:
print(f"创建会话时出错: {e}") print(f"创建会话时出错: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败")
# --- 调用聊天服务获取回复 ---
try: try:
# Pass db session to the service method
ai_reply = await chat_service.get_ai_reply( ai_reply = await chat_service.get_ai_reply(
db=db,
user_message=user_message, user_message=user_message,
session_id=session_id, # 使用真实的或新创建的 session_id session_id=session_id,
assistant_id=assistant_id assistant_id=assistant_id
) )
print(f"发送 AI 回复: '{ai_reply}'") print(f"发送 AI 回复: '{ai_reply}'")
return ChatResponse( return ChatResponse(
reply=ai_reply, reply=ai_reply,
session_id=response_session_id, # 返回新 ID (如果创建了) session_id=response_session_id,
session_title=response_session_title # 返回新标题 (如果创建了) session_title=response_session_title
) )
except ValueError as e: # 助手不存在等错误 except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e: # LLM 调用或其他错误 except Exception as e:
print(f"处理聊天消息时发生错误: {e}") print(f"处理聊天消息时发生错误: {e}")
# The get_db_session dependency will handle rollback
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))

View File

@ -0,0 +1,33 @@
# File: backend/app/api/v1/endpoints/messages.py (New)
# Description: API endpoint for fetching messages
from fastapi import APIRouter, Depends, HTTPException, status, Query
from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session
from app.models.pydantic_models import MessageRead
from app.db.models import MessageModel # Import DB model
from sqlalchemy.future import select
router = APIRouter()
@router.get("/session/{session_id}", response_model=List[MessageRead])
async def read_messages_for_session(
session_id: str,
db: AsyncSession = Depends(get_db_session),
skip: int = Query(0, ge=0), # Offset for pagination
limit: int = Query(100, ge=1, le=500) # Limit number of messages
):
"""获取指定会话的消息列表 (按时间顺序)"""
# TODO: Add check if session exists
stmt = (
select(MessageModel)
.filter(MessageModel.session_id == session_id)
.order_by(MessageModel.order.asc()) # Fetch in chronological order
.offset(skip)
.limit(limit)
)
result = await db.execute(stmt)
messages = result.scalars().all()
# Validate using Pydantic model before returning
return [MessageRead.model_validate(msg) for msg in messages]

View File

@ -1,47 +1,43 @@
# File: backend/app/api/v1/endpoints/sessions.py (新建) # File: backend/app/api/v1/endpoints/sessions.py (Update with DB session dependency)
# Description: 会话管理的 API 路由 # Description: 会话管理的 API 路由 (使用数据库会话)
from fastapi import APIRouter, HTTPException, Depends, status from fastapi import APIRouter, HTTPException, Depends, status
from typing import List from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
from app.services.session_service import session_service_instance, SessionService from app.services.session_service import SessionService # Import the class
router = APIRouter() router = APIRouter()
session_service = SessionService() # Create instance
def get_session_service() -> SessionService:
return session_service_instance
@router.post("/", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED) @router.post("/", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED)
async def create_new_session( async def create_new_session(
session_data: SessionCreateRequest, session_data: SessionCreateRequest,
service: SessionService = Depends(get_session_service) db: AsyncSession = Depends(get_db_session) # Inject DB session
): ):
"""创建新会话并自动生成标题"""
try: try:
return await service.create_session(session_data) return await session_service.create_session(db, session_data)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e: except Exception as e:
# 处理可能的 LLM 调用错误
print(f"创建会话时出错: {e}") print(f"创建会话时出错: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败")
@router.get("/assistant/{assistant_id}", response_model=List[SessionRead]) @router.get("/assistant/{assistant_id}", response_model=List[SessionRead])
async def read_sessions_for_assistant( async def read_sessions_for_assistant(
assistant_id: str, assistant_id: str,
service: SessionService = Depends(get_session_service) db: AsyncSession = Depends(get_db_session)
): ):
"""获取指定助手的所有会话列表""" # Consider adding check if assistant exists first
# TODO: 添加检查助手是否存在 return await session_service.get_sessions_by_assistant(db, assistant_id)
return service.get_sessions_by_assistant(assistant_id)
@router.get("/{session_id}", response_model=SessionRead) @router.get("/{session_id}", response_model=SessionRead)
async def read_session_by_id( async def read_session_by_id(
session_id: str, session_id: str,
service: SessionService = Depends(get_session_service) db: AsyncSession = Depends(get_db_session)
): ):
"""获取单个会话信息""" session = await session_service.get_session(db, session_id)
session = service.get_session(session_id)
if not session: if not session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话")
return session return session
@ -49,10 +45,8 @@ async def read_session_by_id(
@router.delete("/{session_id}", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_existing_session( async def delete_existing_session(
session_id: str, session_id: str,
service: SessionService = Depends(get_session_service) db: AsyncSession = Depends(get_db_session)
): ):
"""删除指定 ID 的会话""" deleted = await session_service.delete_session(db, session_id)
deleted = service.delete_session(session_id)
if not deleted: if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话")

View File

@ -12,4 +12,6 @@ load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # 如果使用 Google GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # 如果使用 Google
# 可以在这里添加其他配置项 # Define the database URL (SQLite in this case)
# DATABASE_URL = "sqlite+aiosqlite:///./cherryai.db" # Use async driver
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./cherryai.db")

View File

@ -0,0 +1,43 @@
# File: backend/app/db/database.py (New - Database setup)
# Description: SQLAlchemy 数据库引擎和会话设置
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from app.core.config import DATABASE_URL
# 创建异步数据库引擎
# connect_args={"check_same_thread": False} is needed only for SQLite.
# It's not needed for other databases.
engine = create_async_engine(DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
# 创建异步会话工厂
# expire_on_commit=False prevents attributes from expiring after commit.
AsyncSessionFactory = sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
# 创建数据模型的基础类
Base = declarative_base()
# --- Dependency to get DB session ---
async def get_db_session() -> AsyncSession:
"""FastAPI dependency to get an async database session."""
async with AsyncSessionFactory() as session:
try:
yield session
await session.commit() # Commit transaction if successful
except Exception:
await session.rollback() # Rollback on error
raise
finally:
await session.close()
# --- Function to create tables (call this on startup) ---
async def create_db_and_tables():
async with engine.begin() as conn:
# await conn.run_sync(Base.metadata.drop_all) # Use drop_all carefully in dev
await conn.run_sync(Base.metadata.create_all)

52
backend/app/db/models.py Normal file
View File

@ -0,0 +1,52 @@
# File: backend/app/db/models.py (New - Database models)
# Description: SQLAlchemy ORM 模型定义
from sqlalchemy import Column, String, Float, ForeignKey, Text, DateTime, Integer
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func # For default timestamps
from app.db.database import Base
import uuid
from datetime import datetime, timezone
def generate_uuid():
return str(uuid.uuid4())
class AssistantModel(Base):
__tablename__ = "assistants"
id = Column(String, primary_key=True, default=generate_uuid)
name = Column(String(50), nullable=False, index=True)
description = Column(String(200), nullable=True)
avatar = Column(String(5), nullable=True)
system_prompt = Column(Text, nullable=False)
model = Column(String, nullable=False)
temperature = Column(Float, nullable=False, default=0.7)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
sessions = relationship("SessionModel", back_populates="assistant", cascade="all, delete-orphan")
class SessionModel(Base):
__tablename__ = "sessions"
id = Column(String, primary_key=True, default=generate_uuid)
title = Column(String(100), nullable=False, default="New Chat")
assistant_id = Column(String, ForeignKey("assistants.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), index=True) # Index for sorting
assistant = relationship("AssistantModel", back_populates="sessions")
messages = relationship("MessageModel", back_populates="session", cascade="all, delete-orphan", order_by="MessageModel.created_at") # Order messages by time
class MessageModel(Base):
__tablename__ = "messages"
id = Column(String, primary_key=True, default=generate_uuid)
session_id = Column(String, ForeignKey("sessions.id"), nullable=False, index=True)
sender = Column(String(10), nullable=False) # 'user' or 'ai' or 'system'
text = Column(Text, nullable=False)
order = Column(Integer, nullable=False) # Explicit order within session
created_at = Column(DateTime(timezone=True), server_default=func.now())
session = relationship("SessionModel", back_populates="messages")

View File

@ -1,20 +1,31 @@
# File: backend/app/main.py (确认 load_dotenv 调用位置) # File: backend/app/main.py (Update - Add startup event)
# Description: FastAPI 应用入口 # Description: FastAPI 应用入口 (添加数据库初始化)
from fastapi import FastAPI from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from app.api.v1.api import api_router as api_router_v1 from app.api.v1.api import api_router as api_router_v1
# 确保在创建 FastAPI 实例之前加载环境变量 import app.core.config # Ensure config is loaded
from app.core.config import OPENAI_API_KEY # 导入会触发 load_dotenv from app.db.database import create_db_and_tables # Import table creation function
from contextlib import asynccontextmanager
# 创建 FastAPI 应用实例 # --- Lifespan context manager for startup/shutdown events ---
app = FastAPI(title="CherryAI Backend", version="0.1.0") @asynccontextmanager
async def lifespan(app: FastAPI):
# Startup actions
print("应用程序启动中...")
await create_db_and_tables() # Create database tables on startup
print("数据库表已检查/创建。")
# You can add the default assistant creation here if needed,
# but doing it in the service/model definition might be simpler for defaults.
yield
# Shutdown actions
print("应用程序关闭中...")
# --- 配置 CORS --- # Create FastAPI app with lifespan context manager
origins = [ app = FastAPI(title="CherryAI Backend", version="0.1.0", lifespan=lifespan)
"http://localhost:3000",
"http://127.0.0.1:3000", # --- CORS Middleware ---
] origins = [ "http://localhost:3000", "http://127.0.0.1:3000" ]
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=origins,
@ -23,10 +34,10 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# --- 挂载 API 路由 --- # --- API Routers ---
app.include_router(api_router_v1, prefix="/api/v1") app.include_router(api_router_v1, prefix="/api/v1")
# --- 根路径 --- # --- Root Endpoint ---
@app.get("/", tags=["Root"]) @app.get("/", tags=["Root"])
async def read_root(): async def read_root():
return {"message": "欢迎来到 CherryAI 后端!"} return {"message": "欢迎来到 CherryAI 后端!"}

View File

@ -1,9 +1,10 @@
# File: backend/app/models/pydantic_models.py (更新) # File: backend/app/models/pydantic_models.py (Update Read models, add Message models)
# Description: Pydantic 模型定义 API 数据结构 # Description: Pydantic 模型定义 API 数据结构
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional, List from typing import Optional, List
import uuid # 用于生成唯一 ID import uuid
from datetime import datetime # Use datetime directly
# --- Assistant Models --- # --- Assistant Models ---
@ -33,7 +34,8 @@ class AssistantUpdate(BaseModel):
class AssistantRead(AssistantBase): class AssistantRead(AssistantBase):
"""读取助手信息时返回的模型 (包含 ID)""" """读取助手信息时返回的模型 (包含 ID)"""
id: str = Field(..., description="助手唯一 ID") id: str = Field(..., description="助手唯一 ID")
created_at: datetime # Add timestamps
updated_at: Optional[datetime] = None
class Config: class Config:
from_attributes = True # Pydantic v2: orm_mode = True from_attributes = True # Pydantic v2: orm_mode = True
@ -70,7 +72,22 @@ class SessionRead(BaseModel):
id: str id: str
title: str title: str
assistant_id: str assistant_id: str
created_at: str created_at: datetime # Use datetime
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
# --- Message Models (New) ---
class MessageBase(BaseModel):
sender: str # 'user' or 'ai'
text: str
class MessageRead(MessageBase):
id: str
session_id: str
order: int
created_at: datetime
class Config: class Config:
from_attributes = True from_attributes = True

View File

@ -1,73 +1,76 @@
# File: backend/app/services/assistant_service.py (新建) # File: backend/app/services/assistant_service.py (Update with DB)
# Description: 管理助手数据的服务 (内存实现) # Description: 管理助手数据的服务 (使用 SQLAlchemy)
from typing import Dict, List, Optional from typing import List, Optional
from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate from sqlalchemy.ext.asyncio import AsyncSession
import uuid from sqlalchemy.future import select
from sqlalchemy import update as sqlalchemy_update, delete as sqlalchemy_delete
# 使用字典作为内存数据库存储助手 from app.db.models import AssistantModel
# key: assistant_id (str), value: AssistantRead object from app.models.pydantic_models import AssistantCreate, AssistantUpdate, AssistantRead
assistants_db: Dict[str, AssistantRead] = {}
# 添加默认助手 (确保 ID 与前端 Mock 一致)
default_assistant = AssistantRead(
id='asst-default',
name='默认助手',
description='通用聊天助手',
avatar='🤖',
system_prompt='你是一个乐于助人的 AI 助手。',
model='gpt-3.5-turbo',
temperature=0.7
)
assistants_db[default_assistant.id] = default_assistant
class AssistantService: class AssistantService:
"""助手数据的 CRUD 服务""" """助手数据的 CRUD 服务 (数据库版)"""
def get_assistants(self) -> List[AssistantRead]: async def get_assistants(self, db: AsyncSession) -> List[AssistantRead]:
"""获取所有助手""" """获取所有助手"""
return list(assistants_db.values()) result = await db.execute(select(AssistantModel).order_by(AssistantModel.name))
assistants = result.scalars().all()
return [AssistantRead.model_validate(a) for a in assistants] # Use model_validate in Pydantic v2
def get_assistant(self, assistant_id: str) -> Optional[AssistantRead]: async def get_assistant(self, db: AsyncSession, assistant_id: str) -> Optional[AssistantRead]:
"""根据 ID 获取单个助手""" """根据 ID 获取单个助手"""
return assistants_db.get(assistant_id) result = await db.execute(select(AssistantModel).filter(AssistantModel.id == assistant_id))
assistant = result.scalars().first()
return AssistantRead.model_validate(assistant) if assistant else None
def create_assistant(self, assistant_data: AssistantCreate) -> AssistantRead: async def create_assistant(self, db: AsyncSession, assistant_data: AssistantCreate) -> AssistantRead:
"""创建新助手""" """创建新助手"""
new_id = f"asst-{uuid.uuid4()}" # 生成唯一 ID # 使用 Pydantic 模型创建 DB 模型实例
new_assistant = AssistantRead(id=new_id, **assistant_data.model_dump()) db_assistant = AssistantModel(**assistant_data.model_dump())
assistants_db[new_id] = new_assistant # ID will be generated by default in the model
print(f"助手已创建: {new_id} - {new_assistant.name}") db.add(db_assistant)
return new_assistant await db.flush() # Flush to get the generated ID and defaults
await db.refresh(db_assistant) # Refresh to load all attributes
print(f"助手已创建 (DB): {db_assistant.id} - {db_assistant.name}")
return AssistantRead.model_validate(db_assistant)
def update_assistant(self, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]: async def update_assistant(self, db: AsyncSession, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]:
"""更新现有助手""" """更新现有助手"""
existing_assistant = assistants_db.get(assistant_id) update_values = assistant_data.model_dump(exclude_unset=True)
if not existing_assistant: if not update_values:
return None # If nothing to update, just fetch and return the existing one
return await self.get_assistant(db, assistant_id)
# 使用 Pydantic 的 model_copy 和 update 来更新字段 # Execute update statement
update_data = assistant_data.model_dump(exclude_unset=True) # 只获取设置了值的字段 stmt = (
if update_data: sqlalchemy_update(AssistantModel)
updated_assistant = existing_assistant.model_copy(update=update_data) .where(AssistantModel.id == assistant_id)
assistants_db[assistant_id] = updated_assistant .values(**update_values)
print(f"助手已更新: {assistant_id}") .returning(AssistantModel) # Return the updated row
return updated_assistant )
return existing_assistant # 如果没有更新任何字段,返回原始助手 result = await db.execute(stmt)
updated_assistant = result.scalars().first()
def delete_assistant(self, assistant_id: str) -> bool: if updated_assistant:
await db.flush()
await db.refresh(updated_assistant)
print(f"助手已更新 (DB): {assistant_id}")
return AssistantRead.model_validate(updated_assistant)
return None # Assistant not found
async def delete_assistant(self, db: AsyncSession, assistant_id: str) -> bool:
"""删除助手""" """删除助手"""
if assistant_id in assistants_db: # Prevent deleting default assistant
# 添加逻辑:不允许删除默认助手 if assistant_id == 'asst-default': # Assuming 'asst-default' is a known ID
if assistant_id == 'asst-default': print("尝试删除默认助手 - 操作被阻止")
print("尝试删除默认助手 - 操作被阻止") return False
return False # 或者抛出特定异常
del assistants_db[assistant_id] stmt = sqlalchemy_delete(AssistantModel).where(AssistantModel.id == assistant_id)
print(f"助手已删除: {assistant_id}") result = await db.execute(stmt)
# TODO: 在实际应用中,还需要删除关联的会话和消息 if result.rowcount > 0:
await db.flush()
print(f"助手已删除 (DB): {assistant_id}")
# Deletion of sessions/messages handled by cascade="all, delete-orphan"
return True return True
return False return False
# 创建服务实例
assistant_service_instance = AssistantService()

View File

@ -1,127 +1,137 @@
# File: backend/app/services/chat_service.py (更新) # File: backend/app/services/chat_service.py (Update with DB for history)
# Description: 封装 LangChain 聊天逻辑 (支持助手配置和会话历史) # Description: 封装 LangChain 聊天逻辑 (使用数据库存储和检索消息)
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
from app.services.assistant_service import assistant_service_instance # 获取助手配置 from sqlalchemy.ext.asyncio import AsyncSession
from app.models.pydantic_models import AssistantRead # 引入助手模型 from sqlalchemy.future import select
import app.core.config as Config from app.db.models import MessageModel, AssistantModel # Import DB models
from app.services.assistant_service import AssistantService # Use class directly
# --- 更新内存管理 --- from app.models.pydantic_models import AssistantRead
# 使用字典存储不同会话的内存
# key: session_id (str), value: List[BaseMessage]
chat_history_db: Dict[str, List[BaseMessage]] = {}
class ChatService: class ChatService:
"""处理 AI 聊天交互的服务 (支持助手配置)""" """处理 AI 聊天交互的服务 (使用数据库历史)"""
def __init__(self, default_api_key: str): def __init__(self, default_api_key: str):
"""初始化时可传入默认 API Key"""
self.default_api_key = default_api_key self.default_api_key = default_api_key
# 不再在 init 中创建固定的 LLM 和 chain self.assistant_service = AssistantService() # Instantiate assistant service
def _get_llm(self, assistant: AssistantRead) -> ChatOpenAI: def _get_llm(self, assistant: AssistantRead) -> ChatOpenAI:
"""根据助手配置动态创建 LLM 实例""" # ... (remains the same) ...
# TODO: 支持不同模型提供商 (Gemini, Anthropic etc.)
if assistant.model.startswith("gpt"): if assistant.model.startswith("gpt"):
return ChatOpenAI( return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)
model=assistant.model,
api_key=self.default_api_key, # 或从助手配置中读取特定 key
temperature=assistant.temperature
)
elif assistant.model.startswith("gemini"): elif assistant.model.startswith("gemini"):
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI( return ChatGoogleGenerativeAI(
model=assistant.model, model=assistant.model,
api_key=self.default_api_key, # 或从助手配置中读取特定 key api_key=self.default_api_key, # 或从助手配置中读取特定 key
temperature=assistant.temperature temperature=assistant.temperature
) )
else: else:
# 默认或抛出错误
print(f"警告: 模型 {assistant.model} 未明确支持,尝试使用 ChatOpenAI") print(f"警告: 模型 {assistant.model} 未明确支持,尝试使用 ChatOpenAI")
return ChatOpenAI( return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)
model=assistant.model,
api_key=self.default_api_key,
temperature=assistant.temperature
)
async def get_ai_reply(self, user_message: str, session_id: str, assistant_id: str) -> str:
""" async def _get_chat_history(self, db: AsyncSession, session_id: str, limit: int = 10) -> Tuple[List[BaseMessage], int]:
获取 AI 对用户消息的回复 (使用指定助手和会话历史) """从数据库加载指定会话的历史消息 (按 order 排序)"""
Args: stmt = (
user_message (str): 用户发送的消息 select(MessageModel)
session_id (str): 会话 ID .filter(MessageModel.session_id == session_id)
assistant_id (str): 使用的助手 ID .order_by(MessageModel.order.desc()) # Get latest first
Returns: .limit(limit)
str: AI 的回复文本 )
Raises: result = await db.execute(stmt)
ValueError: 如果找不到指定的助手 db_messages = result.scalars().all()
Exception: 如果调用 AI 服务时发生错误
""" # Convert to LangChain messages (in correct order: oldest to newest)
history: List[BaseMessage] = []
max_order = 0
for msg in reversed(db_messages): # Reverse to get chronological order
if msg.sender == 'user':
history.append(HumanMessage(content=msg.text))
elif msg.sender == 'ai':
history.append(AIMessage(content=msg.text))
# Add handling for 'system' if needed
max_order = max(max_order, msg.order) # Keep track of the latest order number
return history, max_order
async def _save_message(self, db: AsyncSession, session_id: str, sender: str, text: str, order: int):
"""将消息保存到数据库"""
db_message = MessageModel(
session_id=session_id,
sender=sender,
text=text,
order=order
)
db.add(db_message)
await db.flush() # Ensure it's added before potential commit
print(f"消息已保存 (DB): Session={session_id}, Order={order}, Sender={sender}")
async def get_ai_reply(self, db: AsyncSession, user_message: str, session_id: str, assistant_id: str) -> str:
"""获取 AI 回复,并保存用户消息和 AI 回复到数据库"""
# 1. 获取助手配置 # 1. 获取助手配置
assistant = assistant_service_instance.get_assistant(assistant_id) assistant = await self.assistant_service.get_assistant(db, assistant_id)
if not assistant: if not assistant:
raise ValueError(f"找不到助手 ID: {assistant_id}") raise ValueError(f"找不到助手 ID: {assistant_id}")
# 2. 获取或初始化当前会话的历史记录 # 2. 获取历史记录和下一个序号
current_chat_history = chat_history_db.get(session_id, []) current_chat_history, last_order = await self._get_chat_history(db, session_id)
user_message_order = last_order + 1
ai_message_order = last_order + 2
# 3. 构建 Prompt (包含动态系统提示) # 3. 构建 Prompt
prompt = ChatPromptTemplate.from_messages([ prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=assistant.system_prompt), # 使用助手的系统提示 SystemMessage(content=assistant.system_prompt),
MessagesPlaceholder(variable_name="chat_history"), MessagesPlaceholder(variable_name="chat_history"),
HumanMessage(content="{input}"), HumanMessage(content=user_message),
]) ])
# 4. 获取 LLM 实例 # 4. 获取 LLM
llm = self._get_llm(assistant) llm = self._get_llm(assistant)
# 5. 定义输出解析器
output_parser = StrOutputParser() output_parser = StrOutputParser()
# 6. 构建 LCEL 链
chain = prompt | llm | output_parser chain = prompt | llm | output_parser
try: try:
# 7. 调用链获取回复 # --- Save user message BEFORE calling LLM ---
await self._save_message(db, session_id, 'user', user_message, user_message_order)
# 5. 调用链获取回复
ai_response = await chain.ainvoke({ ai_response = await chain.ainvoke({
"input": user_message, "input": user_message,
"chat_history": current_chat_history, "chat_history": current_chat_history, # Pass history fetched from DB
}) })
# 8. 更新会话历史记录 # --- Save AI response AFTER getting it ---
current_chat_history.append(HumanMessage(content=user_message)) await self._save_message(db, session_id, 'ai', ai_response, ai_message_order)
current_chat_history.append(AIMessage(content=ai_response))
# 限制历史记录长度 (例如最近 10 条消息)
max_history_length = 10
if len(current_chat_history) > max_history_length:
chat_history_db[session_id] = current_chat_history[-max_history_length:]
else:
chat_history_db[session_id] = current_chat_history
# Note: We don't need to manage history in memory anymore (chat_history_db removed)
return ai_response return ai_response
except Exception as e: except Exception as e:
# Consider rolling back the user message save if LLM call fails,
# although often it's better to keep the user message.
# await db.rollback() # Handled by get_db_session dependency on error
print(f"调用 LangChain 时出错 (助手: {assistant_id}, 会话: {session_id}): {e}") print(f"调用 LangChain 时出错 (助手: {assistant_id}, 会话: {session_id}): {e}")
raise Exception(f"AI 服务暂时不可用: {e}") raise Exception(f"AI 服务暂时不可用: {e}")
# (可选) 添加一个简单的文本生成方法用于生成标题 async def generate_text(self, prompt_text: str, model_name: str = "gemini-2.0-flash", temperature: float = 0.5) -> str:
async def generate_text(self, prompt_text: str, model_name: str = "gpt-3.5-turbo", temperature: float = 0.5) -> str: # ... (remains the same) ...
"""使用指定模型生成文本 (用于标题等)""" try:
try: temp_llm = ChatGoogleGenerativeAI(
# 使用一个临时的、可能更便宜的模型 model=model_name,
temp_llm = ChatOpenAI(model=model_name, api_key=self.default_api_key, temperature=temperature) api_key=self.default_api_key, # 或从助手配置中读取特定 key
response = await temp_llm.ainvoke(prompt_text) temperature=temperature
return response.content )
except Exception as e: response = await temp_llm.ainvoke(prompt_text)
print(f"生成文本时出错: {e}") return response.content
return "无法生成标题" # 返回默认值或抛出异常 except Exception as e:
print(f"生成文本时出错: {e}")
return "无法生成标题"
# --- 创建 ChatService 实例 --- # ChatService instance is now created where needed or injected, no global instance here.
if not Config.GOOGLE_API_KEY:
raise ValueError("请在 .env 文件中设置 OPENAI_API_KEY")
chat_service_instance = ChatService(default_api_key=Config.GOOGLE_API_KEY)

View File

@ -1,82 +1,74 @@
# File: backend/app/services/session_service.py (新建) # File: backend/app/services/session_service.py (Update with DB)
# Description: 管理会话数据的服务 (内存实现) # Description: 管理会话数据的服务 (使用 SQLAlchemy)
from typing import Dict, List, Optional from typing import List, Optional
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse, AssistantRead from sqlalchemy.ext.asyncio import AsyncSession
from app.services.assistant_service import assistant_service_instance # 需要获取助手信息 from sqlalchemy.future import select
from sqlalchemy import delete as sqlalchemy_delete
from app.db.models import SessionModel, AssistantModel # Import DB models
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
from datetime import datetime, timezone from datetime import datetime, timezone
import uuid # Import ChatService for title generation (consider refactoring later)
# 导入 ChatService 以调用 LLM 生成标题 (避免循环导入,考虑重构) from app.services.chat_service import ChatService
# from app.services.chat_service import chat_service_instance import app.core.config as Config
chat_service_instance = ChatService(Config.GOOGLE_API_KEY)
# 内存数据库存储会话
# key: session_id (str), value: SessionRead object
sessions_db: Dict[str, SessionRead] = {}
class SessionService: class SessionService:
"""会话数据的 CRUD 及标题生成服务""" """会话数据的 CRUD 及标题生成服务 (数据库版)"""
async def create_session(self, session_data: SessionCreateRequest) -> SessionCreateResponse: async def create_session(self, db: AsyncSession, session_data: SessionCreateRequest) -> SessionCreateResponse:
"""创建新会话并生成标题""" """创建新会话并生成标题"""
assistant = assistant_service_instance.get_assistant(session_data.assistant_id) # 检查助手是否存在
result = await db.execute(select(AssistantModel).filter(AssistantModel.id == session_data.assistant_id))
assistant = result.scalars().first()
if not assistant: if not assistant:
raise ValueError("指定的助手不存在") raise ValueError("指定的助手不存在")
new_id = f"session-{uuid.uuid4()}" # --- 调用 LLM 生成标题 ---
created_time = datetime.now(timezone.utc) try:
title_prompt = f"根据以下用户第一条消息为此对话生成一个简洁的标题不超过10个字:\n\n{session_data.first_message}"
generated_title = await chat_service_instance.generate_text(title_prompt)
except Exception as e:
print(f"生成会话标题时出错: {e}")
generated_title = f"关于 \"{session_data.first_message[:15]}...\"" # Fallback
# --- 生成结束 ---
# --- TODO: 调用 LLM 生成标题 --- db_session = SessionModel(
# title_prompt = f"根据以下用户第一条消息为此对话生成一个简洁的标题不超过10个字:\n\n{session_data.first_message}"
# generated_title = await chat_service_instance.generate_text(title_prompt) # 需要一个简单的文本生成方法
# 模拟标题生成
generated_title = f"关于 \"{session_data.first_message[:15]}...\""
print(f"为新会话 {new_id} 生成标题: {generated_title}")
# --- 模拟结束 ---
new_session = SessionRead(
id=new_id,
title=generated_title, title=generated_title,
assistant_id=session_data.assistant_id, assistant_id=session_data.assistant_id
created_at=created_time.isoformat() # 存储 ISO 格式字符串 # ID and created_at have defaults
) )
sessions_db[new_id] = new_session db.add(db_session)
print(f"会话已创建: {new_id}") await db.flush()
await db.refresh(db_session)
print(f"会话已创建 (DB): {db_session.id}")
return SessionCreateResponse( return SessionCreateResponse(
id=new_session.id, id=db_session.id,
title=new_session.title, title=db_session.title,
assistant_id=new_session.assistant_id, assistant_id=db_session.assistant_id,
created_at=new_session.created_at created_at=db_session.created_at.isoformat() # Use datetime from DB model
) )
def get_sessions_by_assistant(self, assistant_id: str) -> List[SessionRead]: async def get_sessions_by_assistant(self, db: AsyncSession, assistant_id: str) -> List[SessionRead]:
"""获取指定助手的所有会话""" """获取指定助手的所有会话"""
return [s for s in sessions_db.values() if s.assistant_id == assistant_id] stmt = select(SessionModel).filter(SessionModel.assistant_id == assistant_id).order_by(SessionModel.updated_at.desc()) # Order by update time
result = await db.execute(stmt)
sessions = result.scalars().all()
return [SessionRead.model_validate(s) for s in sessions]
def get_session(self, session_id: str) -> Optional[SessionRead]: async def get_session(self, db: AsyncSession, session_id: str) -> Optional[SessionRead]:
"""获取单个会话""" """获取单个会话"""
return sessions_db.get(session_id) result = await db.execute(select(SessionModel).filter(SessionModel.id == session_id))
session = result.scalars().first()
return SessionRead.model_validate(session) if session else None
def delete_session(self, session_id: str) -> bool: async def delete_session(self, db: AsyncSession, session_id: str) -> bool:
"""删除会话""" """删除会话"""
if session_id in sessions_db: stmt = sqlalchemy_delete(SessionModel).where(SessionModel.id == session_id)
del sessions_db[session_id] result = await db.execute(stmt)
print(f"会话已删除: {session_id}") if result.rowcount > 0:
# TODO: 删除关联的消息 await db.flush()
print(f"会话已删除 (DB): {session_id}")
# Deletion of messages handled by cascade
return True return True
return False return False
def delete_sessions_by_assistant(self, assistant_id: str) -> int:
"""删除指定助手的所有会话"""
ids_to_delete = [s.id for s in sessions_db.values() if s.assistant_id == assistant_id]
count = 0
for session_id in ids_to_delete:
if self.delete_session(session_id):
count += 1
print(f"删除了助手 {assistant_id}{count} 个会话")
return count
# 创建服务实例
session_service_instance = SessionService()

BIN
backend/cherryai.db Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@ -45,7 +45,7 @@ export default function RootLayout({
<html lang="zh-CN"> <html lang="zh-CN">
<body className={`${geistSans.variable} ${geistMono.variable} antialiased flex h-screen bg-gray-100 dark:bg-gray-900`}> <body className={`${geistSans.variable} ${geistMono.variable} antialiased flex h-screen bg-gray-100 dark:bg-gray-900`}>
{/* 侧边栏导航 */} {/* 侧边栏导航 */}
<aside className="w-16 bg-white dark:bg-gray-800 p-3 shadow-md flex flex-col items-center mr-1"> {/* 调整内边距和对齐 */} <aside className="w-16 bg-white dark:bg-gray-800 p-3 shadow-md flex flex-col items-center rounded-lg"> {/* 调整内边距和对齐 */}
{/* Logo */} {/* Logo */}
<div className="mb-6"> {/* 调整 Logo 边距 */} <div className="mb-6"> {/* 调整 Logo 边距 */}
<Link href="/" className="flex items-center justify-center text-3xl font-bold text-red-600 dark:text-red-500" title="CherryAI 主页"> <Link href="/" className="flex items-center justify-center text-3xl font-bold text-red-600 dark:text-red-500" title="CherryAI 主页">
@ -60,12 +60,11 @@ export default function RootLayout({
<Link <Link
href={item.href} href={item.href}
className="relative flex items-center justify-center p-2 rounded-lg text-gray-600 dark:text-gray-400 hover:bg-red-100 dark:hover:bg-red-900/50 hover:text-red-700 dark:hover:text-red-400 transition-colors duration-200 group" // 居中图标 className="relative flex items-center justify-center p-2 rounded-lg text-gray-600 dark:text-gray-400 hover:bg-red-100 dark:hover:bg-red-900/50 hover:text-red-700 dark:hover:text-red-400 transition-colors duration-200 group" // 居中图标
title={item.name} // 保留原生 title
> >
<item.icon className="h-6 w-6 flex-shrink-0 text-gray-600 dark:text-gray-400" /> <item.icon className="h-6 w-6 flex-shrink-0 text-gray-600 dark:text-gray-400" />
{/* Tooltip 文字标签 */} {/* Tooltip 文字标签 */}
<span <span
className="absolute left-full top-1/2 -translate-y-1/2 ml-3 px-2 py-1 bg-gray-900 dark:bg-gray-700 text-white text-xl rounded shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-200 delay-150 whitespace-nowrap pointer-events-none" // 使用 pointer-events-none 避免干扰悬浮 className="z-10 absolute left-full top-1/2 -translate-y-1/2 ml-3 px-2 py-1 bg-gray-900 dark:bg-gray-700 text-white text-xl rounded shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-200 delay-150 whitespace-nowrap pointer-events-none" // 使用 pointer-events-none 避免干扰悬浮
> >
{item.name} {item.name}
</span> </span>

View File

@ -1,40 +1,28 @@
// File: frontend/lib/api.ts (更新) // File: frontend/lib/api.ts (更新)
// Description: 添加调用助手和会话管理 API 的函数 // Description: 添加调用助手和会话管理 API 的函数
import { Assistant, AssistantCreateData, AssistantUpdateData } from "@/types/assistant";
import axios from "axios"; import axios from "axios";
// --- Types (从后端模型同步或手动定义) --- // --- Types ---
// 这些类型应该与后端 pydantic_models.py 中的 Read 模型匹配
export interface Assistant {
id: string;
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
export interface Session { export interface Session {
id: string; id: string;
title: string; title: string;
assistant_id: string; assistant_id: string;
created_at: string; // ISO date string created_at: string; // ISO date string
updated_at?: string | null; // Add updated_at
} }
// 创建助手时发送的数据类型 // Message type from backend
export interface AssistantCreateData { export interface Message {
name: string; id: string;
description?: string | null; session_id: string;
avatar?: string | null; sender: 'user' | 'ai'; // Or extend with 'system' if needed
system_prompt: string; text: string;
model: string; order: number;
temperature: number; created_at: string; // ISO date string
} }
// 更新助手时发送的数据类型 (所有字段可选)
export type AssistantUpdateData = Partial<AssistantCreateData>;
// 聊天响应类型 // 聊天响应类型
export interface ChatApiResponse { export interface ChatApiResponse {
reply: string; reply: string;
@ -169,4 +157,19 @@ export const deleteSession = async (sessionId: string): Promise<void> => {
//当前端发送 sessionId 为 'temp-new-chat' 的消息时,后端会自动创建。 //当前端发送 sessionId 为 'temp-new-chat' 的消息时,后端会自动创建。
//如果需要单独创建会话(例如,不发送消息就创建),则需要单独实现前端调用 POST /sessions/。 //如果需要单独创建会话(例如,不发送消息就创建),则需要单独实现前端调用 POST /sessions/。
// TODO: 添加获取会话消息的 API 函数 (GET /sessions/{session_id}/messages) // --- Message API (New) ---
/** 获取指定会话的消息列表 */
export const getMessagesBySession = async (sessionId: string, limit: number = 100, skip: number = 0): Promise<Message[]> => {
try {
const response = await apiClient.get<Message[]>(`/messages/session/${sessionId}`, {
params: { limit, skip }
});
return response.data;
} catch (error) {
// Handle 404 specifically if needed (session exists but no messages)
if (axios.isAxiosError(error) && error.response?.status === 404) {
return []; // Return empty list if session not found or no messages
}
throw new Error(handleApiError(error, 'getMessagesBySession'));
}
};

View File

@ -0,0 +1,23 @@
// --- Types (从后端模型同步或手动定义) ---
// 这些类型应该与后端 pydantic_models.py 中的 Read 模型匹配
export interface Assistant {
id: string;
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
// 创建助手时发送的数据类型
export interface AssistantCreateData {
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
// 更新助手时发送的数据类型 (所有字段可选)
export type AssistantUpdateData = Partial<AssistantCreateData>;