添加数据管理

This commit is contained in:
adrian 2025-04-30 04:39:36 +08:00
parent 7df10e82be
commit f0863914c2
18 changed files with 1108 additions and 676 deletions

View File

@ -1,12 +1,12 @@
# File: backend/app/api/v1/api.py (更新)
# File: backend/app/api/v1/api.py (Update)
# Description: 聚合 v1 版本的所有 API 路由
from fastapi import APIRouter
from app.api.v1.endpoints import chat, assistants, sessions # 导入新路由
from app.api.v1.endpoints import chat, assistants, sessions, messages # Import messages router
api_router = APIRouter()
api_router.include_router(chat.router, prefix="/chat", tags=["Chat"])
api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"]) # 添加助手路由
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"]) # 添加会话路由
api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"])
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
api_router.include_router(messages.router, prefix="/messages", tags=["Messages"]) # Add messages router

View File

@ -1,39 +1,39 @@
# File: backend/app/api/v1/endpoints/assistants.py (新建)
# Description: 助手的 API 路由
# File: backend/app/api/v1/endpoints/assistants.py (Update with DB session dependency)
# Description: 助手的 API 路由 (使用数据库会话)
from fastapi import APIRouter, HTTPException, Depends, status
from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session # Import DB session dependency
from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate
from app.services.assistant_service import assistant_service_instance, AssistantService
from app.services.assistant_service import AssistantService # Import the class
router = APIRouter()
# --- 依赖注入 AssistantService ---
def get_assistant_service() -> AssistantService:
return assistant_service_instance
# --- Dependency Injection for Service and DB Session ---
# Service instance can be created per request or globally
# For simplicity, let's create it here, but pass db session to methods
assistant_service = AssistantService()
@router.post("/", response_model=AssistantRead, status_code=status.HTTP_201_CREATED)
async def create_new_assistant(
assistant_data: AssistantCreate,
service: AssistantService = Depends(get_assistant_service)
db: AsyncSession = Depends(get_db_session) # Inject DB session
):
"""创建新助手"""
return service.create_assistant(assistant_data)
return await assistant_service.create_assistant(db, assistant_data)
@router.get("/", response_model=List[AssistantRead])
async def read_all_assistants(
service: AssistantService = Depends(get_assistant_service)
db: AsyncSession = Depends(get_db_session)
):
"""获取所有助手列表"""
return service.get_assistants()
return await assistant_service.get_assistants(db)
@router.get("/{assistant_id}", response_model=AssistantRead)
async def read_assistant_by_id(
assistant_id: str,
service: AssistantService = Depends(get_assistant_service)
db: AsyncSession = Depends(get_db_session)
):
"""根据 ID 获取特定助手"""
assistant = service.get_assistant(assistant_id)
assistant = await assistant_service.get_assistant(db, assistant_id)
if not assistant:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
return assistant
@ -42,10 +42,9 @@ async def read_assistant_by_id(
async def update_existing_assistant(
assistant_id: str,
assistant_data: AssistantUpdate,
service: AssistantService = Depends(get_assistant_service)
db: AsyncSession = Depends(get_db_session)
):
"""更新指定 ID 的助手"""
updated_assistant = service.update_assistant(assistant_id, assistant_data)
updated_assistant = await assistant_service.update_assistant(db, assistant_id, assistant_data)
if not updated_assistant:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
return updated_assistant
@ -53,14 +52,17 @@ async def update_existing_assistant(
@router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_existing_assistant(
assistant_id: str,
service: AssistantService = Depends(get_assistant_service)
db: AsyncSession = Depends(get_db_session)
):
"""删除指定 ID 的助手"""
deleted = service.delete_assistant(assistant_id)
# Handle potential error from service if trying to delete default
try:
deleted = await assistant_service.delete_assistant(db, assistant_id)
if not deleted:
# 根据服务层逻辑判断是找不到还是不允许删除
assistant = service.get_assistant(assistant_id)
# Check if it exists to differentiate 404 from 403 (or handle in service)
assistant = await assistant_service.get_assistant(db, assistant_id)
if assistant and assistant_id == 'asst-default':
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="不允许删除默认助手")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
# 成功删除,不返回内容
except Exception as e: # Catch other potential DB errors
print(f"删除助手时出错: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除助手失败")

View File

@ -1,27 +1,26 @@
# File: backend/app/api/v1/endpoints/chat.py (更新)
# Description: 聊天功能的 API 路由 (使用更新后的 ChatService)
# File: backend/app/api/v1/endpoints/chat.py (Update with DB session dependency)
# Description: 聊天功能的 API 路由 (使用数据库会话)
from fastapi import APIRouter, HTTPException, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session
from app.models.pydantic_models import ChatRequest, ChatResponse, SessionCreateRequest
from app.services.chat_service import chat_service_instance, ChatService
from app.services.session_service import session_service_instance, SessionService # 导入 SessionService
from app.services.chat_service import ChatService # Import class
from app.services.session_service import SessionService # Import class
import app.core.config as Config # Import API Key for ChatService instantiation
router = APIRouter()
# --- 依赖注入 ---
def get_chat_service() -> ChatService:
return chat_service_instance
def get_session_service() -> SessionService:
return session_service_instance
# --- Dependency Injection ---
# Instantiate services here or use a more sophisticated dependency injection system
chat_service = ChatService(default_api_key=Config.GOOGLE_API_KEY)
session_service = SessionService()
@router.post("/", response_model=ChatResponse)
async def handle_chat_message(
request: ChatRequest,
chat_service: ChatService = Depends(get_chat_service),
session_service: SessionService = Depends(get_session_service) # 注入 SessionService
db: AsyncSession = Depends(get_db_session) # Inject DB session
):
"""处理用户发送的聊天消息 (包含 assistantId 和 sessionId)"""
user_message = request.message
session_id = request.session_id
assistant_id = request.assistant_id
@ -31,38 +30,39 @@ async def handle_chat_message(
response_session_id = None
response_session_title = None
# --- 处理临时新会话 ---
if session_id == 'temp-new-chat':
print("检测到临时新会话,正在创建...")
try:
# 调用 SessionService 创建会话
create_req = SessionCreateRequest(assistant_id=assistant_id, first_message=user_message)
created_session = await session_service.create_session(create_req)
session_id = created_session.id # 使用新创建的会话 ID
response_session_id = created_session.id # 准备在响应中返回新 ID
response_session_title = created_session.title # 准备在响应中返回新标题
# Pass db session to the service method
created_session = await session_service.create_session(db, create_req)
session_id = created_session.id
response_session_id = created_session.id
response_session_title = created_session.title
print(f"新会话已创建: ID={session_id}, Title='{created_session.title}'")
except ValueError as e: # 助手不存在等错误
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e: # LLM 调用或其他错误
except Exception as e:
print(f"创建会话时出错: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败")
# --- 调用聊天服务获取回复 ---
try:
# Pass db session to the service method
ai_reply = await chat_service.get_ai_reply(
db=db,
user_message=user_message,
session_id=session_id, # 使用真实的或新创建的 session_id
session_id=session_id,
assistant_id=assistant_id
)
print(f"发送 AI 回复: '{ai_reply}'")
return ChatResponse(
reply=ai_reply,
session_id=response_session_id, # 返回新 ID (如果创建了)
session_title=response_session_title # 返回新标题 (如果创建了)
session_id=response_session_id,
session_title=response_session_title
)
except ValueError as e: # 助手不存在等错误
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e: # LLM 调用或其他错误
except Exception as e:
print(f"处理聊天消息时发生错误: {e}")
# The get_db_session dependency will handle rollback
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))

View File

@ -0,0 +1,33 @@
# File: backend/app/api/v1/endpoints/messages.py (New)
# Description: API endpoint for fetching messages
from fastapi import APIRouter, Depends, HTTPException, status, Query
from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session
from app.models.pydantic_models import MessageRead
from app.db.models import MessageModel # Import DB model
from sqlalchemy.future import select
router = APIRouter()
@router.get("/session/{session_id}", response_model=List[MessageRead])
async def read_messages_for_session(
session_id: str,
db: AsyncSession = Depends(get_db_session),
skip: int = Query(0, ge=0), # Offset for pagination
limit: int = Query(100, ge=1, le=500) # Limit number of messages
):
"""获取指定会话的消息列表 (按时间顺序)"""
# TODO: Add check if session exists
stmt = (
select(MessageModel)
.filter(MessageModel.session_id == session_id)
.order_by(MessageModel.order.asc()) # Fetch in chronological order
.offset(skip)
.limit(limit)
)
result = await db.execute(stmt)
messages = result.scalars().all()
# Validate using Pydantic model before returning
return [MessageRead.model_validate(msg) for msg in messages]

View File

@ -1,47 +1,43 @@
# File: backend/app/api/v1/endpoints/sessions.py (新建)
# Description: 会话管理的 API 路由
# File: backend/app/api/v1/endpoints/sessions.py (Update with DB session dependency)
# Description: 会话管理的 API 路由 (使用数据库会话)
from fastapi import APIRouter, HTTPException, Depends, status
from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
from app.services.session_service import session_service_instance, SessionService
from app.services.session_service import SessionService # Import the class
router = APIRouter()
def get_session_service() -> SessionService:
return session_service_instance
session_service = SessionService() # Create instance
@router.post("/", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED)
async def create_new_session(
session_data: SessionCreateRequest,
service: SessionService = Depends(get_session_service)
db: AsyncSession = Depends(get_db_session) # Inject DB session
):
"""创建新会话并自动生成标题"""
try:
return await service.create_session(session_data)
return await session_service.create_session(db, session_data)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
# 处理可能的 LLM 调用错误
print(f"创建会话时出错: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败")
@router.get("/assistant/{assistant_id}", response_model=List[SessionRead])
async def read_sessions_for_assistant(
assistant_id: str,
service: SessionService = Depends(get_session_service)
db: AsyncSession = Depends(get_db_session)
):
"""获取指定助手的所有会话列表"""
# TODO: 添加检查助手是否存在
return service.get_sessions_by_assistant(assistant_id)
# Consider adding check if assistant exists first
return await session_service.get_sessions_by_assistant(db, assistant_id)
@router.get("/{session_id}", response_model=SessionRead)
async def read_session_by_id(
session_id: str,
service: SessionService = Depends(get_session_service)
db: AsyncSession = Depends(get_db_session)
):
"""获取单个会话信息"""
session = service.get_session(session_id)
session = await session_service.get_session(db, session_id)
if not session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话")
return session
@ -49,10 +45,8 @@ async def read_session_by_id(
@router.delete("/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_existing_session(
session_id: str,
service: SessionService = Depends(get_session_service)
db: AsyncSession = Depends(get_db_session)
):
"""删除指定 ID 的会话"""
deleted = service.delete_session(session_id)
deleted = await session_service.delete_session(db, session_id)
if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话")

View File

@ -12,4 +12,6 @@ load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # 如果使用 Google
# 可以在这里添加其他配置项
# Define the database URL (SQLite in this case)
# DATABASE_URL = "sqlite+aiosqlite:///./cherryai.db" # Use async driver
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./cherryai.db")

View File

@ -0,0 +1,43 @@
# File: backend/app/db/database.py (New - Database setup)
# Description: SQLAlchemy 数据库引擎和会话设置
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from app.core.config import DATABASE_URL
# 创建异步数据库引擎
# connect_args={"check_same_thread": False} is needed only for SQLite.
# It's not needed for other databases.
engine = create_async_engine(DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
# 创建异步会话工厂
# expire_on_commit=False prevents attributes from expiring after commit.
AsyncSessionFactory = sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
# 创建数据模型的基础类
Base = declarative_base()
# --- Dependency to get DB session ---
async def get_db_session() -> AsyncSession:
"""FastAPI dependency to get an async database session."""
async with AsyncSessionFactory() as session:
try:
yield session
await session.commit() # Commit transaction if successful
except Exception:
await session.rollback() # Rollback on error
raise
finally:
await session.close()
# --- Function to create tables (call this on startup) ---
async def create_db_and_tables():
async with engine.begin() as conn:
# await conn.run_sync(Base.metadata.drop_all) # Use drop_all carefully in dev
await conn.run_sync(Base.metadata.create_all)

52
backend/app/db/models.py Normal file
View File

@ -0,0 +1,52 @@
# File: backend/app/db/models.py (New - Database models)
# Description: SQLAlchemy ORM 模型定义
from sqlalchemy import Column, String, Float, ForeignKey, Text, DateTime, Integer
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func # For default timestamps
from app.db.database import Base
import uuid
from datetime import datetime, timezone
def generate_uuid():
return str(uuid.uuid4())
class AssistantModel(Base):
__tablename__ = "assistants"
id = Column(String, primary_key=True, default=generate_uuid)
name = Column(String(50), nullable=False, index=True)
description = Column(String(200), nullable=True)
avatar = Column(String(5), nullable=True)
system_prompt = Column(Text, nullable=False)
model = Column(String, nullable=False)
temperature = Column(Float, nullable=False, default=0.7)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
sessions = relationship("SessionModel", back_populates="assistant", cascade="all, delete-orphan")
class SessionModel(Base):
__tablename__ = "sessions"
id = Column(String, primary_key=True, default=generate_uuid)
title = Column(String(100), nullable=False, default="New Chat")
assistant_id = Column(String, ForeignKey("assistants.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), index=True) # Index for sorting
assistant = relationship("AssistantModel", back_populates="sessions")
messages = relationship("MessageModel", back_populates="session", cascade="all, delete-orphan", order_by="MessageModel.created_at") # Order messages by time
class MessageModel(Base):
__tablename__ = "messages"
id = Column(String, primary_key=True, default=generate_uuid)
session_id = Column(String, ForeignKey("sessions.id"), nullable=False, index=True)
sender = Column(String(10), nullable=False) # 'user' or 'ai' or 'system'
text = Column(Text, nullable=False)
order = Column(Integer, nullable=False) # Explicit order within session
created_at = Column(DateTime(timezone=True), server_default=func.now())
session = relationship("SessionModel", back_populates="messages")

View File

@ -1,20 +1,31 @@
# File: backend/app/main.py (确认 load_dotenv 调用位置)
# Description: FastAPI 应用入口
# File: backend/app/main.py (Update - Add startup event)
# Description: FastAPI 应用入口 (添加数据库初始化)
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from app.api.v1.api import api_router as api_router_v1
# 确保在创建 FastAPI 实例之前加载环境变量
from app.core.config import OPENAI_API_KEY # 导入会触发 load_dotenv
import app.core.config # Ensure config is loaded
from app.db.database import create_db_and_tables # Import table creation function
from contextlib import asynccontextmanager
# 创建 FastAPI 应用实例
app = FastAPI(title="CherryAI Backend", version="0.1.0")
# --- Lifespan context manager for startup/shutdown events ---
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup actions
print("应用程序启动中...")
await create_db_and_tables() # Create database tables on startup
print("数据库表已检查/创建。")
# You can add the default assistant creation here if needed,
# but doing it in the service/model definition might be simpler for defaults.
yield
# Shutdown actions
print("应用程序关闭中...")
# --- 配置 CORS ---
origins = [
"http://localhost:3000",
"http://127.0.0.1:3000",
]
# Create FastAPI app with lifespan context manager
app = FastAPI(title="CherryAI Backend", version="0.1.0", lifespan=lifespan)
# --- CORS Middleware ---
origins = [ "http://localhost:3000", "http://127.0.0.1:3000" ]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
@ -23,10 +34,10 @@ app.add_middleware(
allow_headers=["*"],
)
# --- 挂载 API 路由 ---
# --- API Routers ---
app.include_router(api_router_v1, prefix="/api/v1")
# --- 根路径 ---
# --- Root Endpoint ---
@app.get("/", tags=["Root"])
async def read_root():
return {"message": "欢迎来到 CherryAI 后端!"}

View File

@ -1,9 +1,10 @@
# File: backend/app/models/pydantic_models.py (更新)
# File: backend/app/models/pydantic_models.py (Update Read models, add Message models)
# Description: Pydantic 模型定义 API 数据结构
from pydantic import BaseModel, Field
from typing import Optional, List
import uuid # 用于生成唯一 ID
import uuid
from datetime import datetime # Use datetime directly
# --- Assistant Models ---
@ -33,7 +34,8 @@ class AssistantUpdate(BaseModel):
class AssistantRead(AssistantBase):
"""读取助手信息时返回的模型 (包含 ID)"""
id: str = Field(..., description="助手唯一 ID")
created_at: datetime # Add timestamps
updated_at: Optional[datetime] = None
class Config:
from_attributes = True # Pydantic v2: orm_mode = True
@ -70,7 +72,22 @@ class SessionRead(BaseModel):
id: str
title: str
assistant_id: str
created_at: str
created_at: datetime # Use datetime
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
# --- Message Models (New) ---
class MessageBase(BaseModel):
sender: str # 'user' or 'ai'
text: str
class MessageRead(MessageBase):
id: str
session_id: str
order: int
created_at: datetime
class Config:
from_attributes = True

View File

@ -1,73 +1,76 @@
# File: backend/app/services/assistant_service.py (新建)
# Description: 管理助手数据的服务 (内存实现)
# File: backend/app/services/assistant_service.py (Update with DB)
# Description: 管理助手数据的服务 (使用 SQLAlchemy)
from typing import Dict, List, Optional
from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate
import uuid
# 使用字典作为内存数据库存储助手
# key: assistant_id (str), value: AssistantRead object
assistants_db: Dict[str, AssistantRead] = {}
# 添加默认助手 (确保 ID 与前端 Mock 一致)
default_assistant = AssistantRead(
id='asst-default',
name='默认助手',
description='通用聊天助手',
avatar='🤖',
system_prompt='你是一个乐于助人的 AI 助手。',
model='gpt-3.5-turbo',
temperature=0.7
)
assistants_db[default_assistant.id] = default_assistant
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update as sqlalchemy_update, delete as sqlalchemy_delete
from app.db.models import AssistantModel
from app.models.pydantic_models import AssistantCreate, AssistantUpdate, AssistantRead
class AssistantService:
"""助手数据的 CRUD 服务"""
"""助手数据的 CRUD 服务 (数据库版)"""
def get_assistants(self) -> List[AssistantRead]:
async def get_assistants(self, db: AsyncSession) -> List[AssistantRead]:
"""获取所有助手"""
return list(assistants_db.values())
result = await db.execute(select(AssistantModel).order_by(AssistantModel.name))
assistants = result.scalars().all()
return [AssistantRead.model_validate(a) for a in assistants] # Use model_validate in Pydantic v2
def get_assistant(self, assistant_id: str) -> Optional[AssistantRead]:
async def get_assistant(self, db: AsyncSession, assistant_id: str) -> Optional[AssistantRead]:
"""根据 ID 获取单个助手"""
return assistants_db.get(assistant_id)
result = await db.execute(select(AssistantModel).filter(AssistantModel.id == assistant_id))
assistant = result.scalars().first()
return AssistantRead.model_validate(assistant) if assistant else None
def create_assistant(self, assistant_data: AssistantCreate) -> AssistantRead:
async def create_assistant(self, db: AsyncSession, assistant_data: AssistantCreate) -> AssistantRead:
"""创建新助手"""
new_id = f"asst-{uuid.uuid4()}" # 生成唯一 ID
new_assistant = AssistantRead(id=new_id, **assistant_data.model_dump())
assistants_db[new_id] = new_assistant
print(f"助手已创建: {new_id} - {new_assistant.name}")
return new_assistant
# 使用 Pydantic 模型创建 DB 模型实例
db_assistant = AssistantModel(**assistant_data.model_dump())
# ID will be generated by default in the model
db.add(db_assistant)
await db.flush() # Flush to get the generated ID and defaults
await db.refresh(db_assistant) # Refresh to load all attributes
print(f"助手已创建 (DB): {db_assistant.id} - {db_assistant.name}")
return AssistantRead.model_validate(db_assistant)
def update_assistant(self, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]:
async def update_assistant(self, db: AsyncSession, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]:
"""更新现有助手"""
existing_assistant = assistants_db.get(assistant_id)
if not existing_assistant:
return None
update_values = assistant_data.model_dump(exclude_unset=True)
if not update_values:
# If nothing to update, just fetch and return the existing one
return await self.get_assistant(db, assistant_id)
# 使用 Pydantic 的 model_copy 和 update 来更新字段
update_data = assistant_data.model_dump(exclude_unset=True) # 只获取设置了值的字段
if update_data:
updated_assistant = existing_assistant.model_copy(update=update_data)
assistants_db[assistant_id] = updated_assistant
print(f"助手已更新: {assistant_id}")
return updated_assistant
return existing_assistant # 如果没有更新任何字段,返回原始助手
# Execute update statement
stmt = (
sqlalchemy_update(AssistantModel)
.where(AssistantModel.id == assistant_id)
.values(**update_values)
.returning(AssistantModel) # Return the updated row
)
result = await db.execute(stmt)
updated_assistant = result.scalars().first()
def delete_assistant(self, assistant_id: str) -> bool:
if updated_assistant:
await db.flush()
await db.refresh(updated_assistant)
print(f"助手已更新 (DB): {assistant_id}")
return AssistantRead.model_validate(updated_assistant)
return None # Assistant not found
async def delete_assistant(self, db: AsyncSession, assistant_id: str) -> bool:
"""删除助手"""
if assistant_id in assistants_db:
# 添加逻辑:不允许删除默认助手
if assistant_id == 'asst-default':
# Prevent deleting default assistant
if assistant_id == 'asst-default': # Assuming 'asst-default' is a known ID
print("尝试删除默认助手 - 操作被阻止")
return False # 或者抛出特定异常
del assistants_db[assistant_id]
print(f"助手已删除: {assistant_id}")
# TODO: 在实际应用中,还需要删除关联的会话和消息
return False
stmt = sqlalchemy_delete(AssistantModel).where(AssistantModel.id == assistant_id)
result = await db.execute(stmt)
if result.rowcount > 0:
await db.flush()
print(f"助手已删除 (DB): {assistant_id}")
# Deletion of sessions/messages handled by cascade="all, delete-orphan"
return True
return False
# 创建服务实例
assistant_service_instance = AssistantService()

View File

@ -1,127 +1,137 @@
# File: backend/app/services/chat_service.py (更新)
# Description: 封装 LangChain 聊天逻辑 (支持助手配置和会话历史)
# File: backend/app/services/chat_service.py (Update with DB for history)
# Description: 封装 LangChain 聊天逻辑 (使用数据库存储和检索消息)
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from typing import Dict, List, Optional
from app.services.assistant_service import assistant_service_instance # 获取助手配置
from app.models.pydantic_models import AssistantRead # 引入助手模型
import app.core.config as Config
# --- 更新内存管理 ---
# 使用字典存储不同会话的内存
# key: session_id (str), value: List[BaseMessage]
chat_history_db: Dict[str, List[BaseMessage]] = {}
from typing import Dict, List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db.models import MessageModel, AssistantModel # Import DB models
from app.services.assistant_service import AssistantService # Use class directly
from app.models.pydantic_models import AssistantRead
class ChatService:
"""处理 AI 聊天交互的服务 (支持助手配置)"""
"""处理 AI 聊天交互的服务 (使用数据库历史)"""
def __init__(self, default_api_key: str):
"""初始化时可传入默认 API Key"""
self.default_api_key = default_api_key
# 不再在 init 中创建固定的 LLM 和 chain
self.assistant_service = AssistantService() # Instantiate assistant service
def _get_llm(self, assistant: AssistantRead) -> ChatOpenAI:
"""根据助手配置动态创建 LLM 实例"""
# TODO: 支持不同模型提供商 (Gemini, Anthropic etc.)
# ... (remains the same) ...
if assistant.model.startswith("gpt"):
return ChatOpenAI(
model=assistant.model,
api_key=self.default_api_key, # 或从助手配置中读取特定 key
temperature=assistant.temperature
)
return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)
elif assistant.model.startswith("gemini"):
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=assistant.model,
api_key=self.default_api_key, # 或从助手配置中读取特定 key
temperature=assistant.temperature
)
else:
# 默认或抛出错误
print(f"警告: 模型 {assistant.model} 未明确支持,尝试使用 ChatOpenAI")
return ChatOpenAI(
model=assistant.model,
api_key=self.default_api_key,
temperature=assistant.temperature
)
return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)
async def get_ai_reply(self, user_message: str, session_id: str, assistant_id: str) -> str:
"""
获取 AI 对用户消息的回复 (使用指定助手和会话历史)
Args:
user_message (str): 用户发送的消息
session_id (str): 会话 ID
assistant_id (str): 使用的助手 ID
Returns:
str: AI 的回复文本
Raises:
ValueError: 如果找不到指定的助手
Exception: 如果调用 AI 服务时发生错误
"""
async def _get_chat_history(self, db: AsyncSession, session_id: str, limit: int = 10) -> Tuple[List[BaseMessage], int]:
"""从数据库加载指定会话的历史消息 (按 order 排序)"""
stmt = (
select(MessageModel)
.filter(MessageModel.session_id == session_id)
.order_by(MessageModel.order.desc()) # Get latest first
.limit(limit)
)
result = await db.execute(stmt)
db_messages = result.scalars().all()
# Convert to LangChain messages (in correct order: oldest to newest)
history: List[BaseMessage] = []
max_order = 0
for msg in reversed(db_messages): # Reverse to get chronological order
if msg.sender == 'user':
history.append(HumanMessage(content=msg.text))
elif msg.sender == 'ai':
history.append(AIMessage(content=msg.text))
# Add handling for 'system' if needed
max_order = max(max_order, msg.order) # Keep track of the latest order number
return history, max_order
async def _save_message(self, db: AsyncSession, session_id: str, sender: str, text: str, order: int):
"""将消息保存到数据库"""
db_message = MessageModel(
session_id=session_id,
sender=sender,
text=text,
order=order
)
db.add(db_message)
await db.flush() # Ensure it's added before potential commit
print(f"消息已保存 (DB): Session={session_id}, Order={order}, Sender={sender}")
async def get_ai_reply(self, db: AsyncSession, user_message: str, session_id: str, assistant_id: str) -> str:
"""获取 AI 回复,并保存用户消息和 AI 回复到数据库"""
# 1. 获取助手配置
assistant = assistant_service_instance.get_assistant(assistant_id)
assistant = await self.assistant_service.get_assistant(db, assistant_id)
if not assistant:
raise ValueError(f"找不到助手 ID: {assistant_id}")
# 2. 获取或初始化当前会话的历史记录
current_chat_history = chat_history_db.get(session_id, [])
# 2. 获取历史记录和下一个序号
current_chat_history, last_order = await self._get_chat_history(db, session_id)
user_message_order = last_order + 1
ai_message_order = last_order + 2
# 3. 构建 Prompt (包含动态系统提示)
# 3. 构建 Prompt
prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=assistant.system_prompt), # 使用助手的系统提示
SystemMessage(content=assistant.system_prompt),
MessagesPlaceholder(variable_name="chat_history"),
HumanMessage(content="{input}"),
HumanMessage(content=user_message),
])
# 4. 获取 LLM 实例
# 4. 获取 LLM
llm = self._get_llm(assistant)
# 5. 定义输出解析器
output_parser = StrOutputParser()
# 6. 构建 LCEL 链
chain = prompt | llm | output_parser
try:
# 7. 调用链获取回复
# --- Save user message BEFORE calling LLM ---
await self._save_message(db, session_id, 'user', user_message, user_message_order)
# 5. 调用链获取回复
ai_response = await chain.ainvoke({
"input": user_message,
"chat_history": current_chat_history,
"chat_history": current_chat_history, # Pass history fetched from DB
})
# 8. 更新会话历史记录
current_chat_history.append(HumanMessage(content=user_message))
current_chat_history.append(AIMessage(content=ai_response))
# 限制历史记录长度 (例如最近 10 条消息)
max_history_length = 10
if len(current_chat_history) > max_history_length:
chat_history_db[session_id] = current_chat_history[-max_history_length:]
else:
chat_history_db[session_id] = current_chat_history
# --- Save AI response AFTER getting it ---
await self._save_message(db, session_id, 'ai', ai_response, ai_message_order)
# Note: We don't need to manage history in memory anymore (chat_history_db removed)
return ai_response
except Exception as e:
# Consider rolling back the user message save if LLM call fails,
# although often it's better to keep the user message.
# await db.rollback() # Handled by get_db_session dependency on error
print(f"调用 LangChain 时出错 (助手: {assistant_id}, 会话: {session_id}): {e}")
raise Exception(f"AI 服务暂时不可用: {e}")
# (可选) 添加一个简单的文本生成方法用于生成标题
async def generate_text(self, prompt_text: str, model_name: str = "gpt-3.5-turbo", temperature: float = 0.5) -> str:
"""使用指定模型生成文本 (用于标题等)"""
async def generate_text(self, prompt_text: str, model_name: str = "gemini-2.0-flash", temperature: float = 0.5) -> str:
# ... (remains the same) ...
try:
# 使用一个临时的、可能更便宜的模型
temp_llm = ChatOpenAI(model=model_name, api_key=self.default_api_key, temperature=temperature)
temp_llm = ChatGoogleGenerativeAI(
model=model_name,
api_key=self.default_api_key, # 或从助手配置中读取特定 key
temperature=temperature
)
response = await temp_llm.ainvoke(prompt_text)
return response.content
except Exception as e:
print(f"生成文本时出错: {e}")
return "无法生成标题" # 返回默认值或抛出异常
return "无法生成标题"
# --- 创建 ChatService 实例 ---
if not Config.GOOGLE_API_KEY:
raise ValueError("请在 .env 文件中设置 OPENAI_API_KEY")
chat_service_instance = ChatService(default_api_key=Config.GOOGLE_API_KEY)
# ChatService instance is now created where needed or injected, no global instance here.

View File

@ -1,82 +1,74 @@
# File: backend/app/services/session_service.py (新建)
# Description: 管理会话数据的服务 (内存实现)
# File: backend/app/services/session_service.py (Update with DB)
# Description: 管理会话数据的服务 (使用 SQLAlchemy)
from typing import Dict, List, Optional
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse, AssistantRead
from app.services.assistant_service import assistant_service_instance # 需要获取助手信息
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import delete as sqlalchemy_delete
from app.db.models import SessionModel, AssistantModel # Import DB models
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
from datetime import datetime, timezone
import uuid
# 导入 ChatService 以调用 LLM 生成标题 (避免循环导入,考虑重构)
# from app.services.chat_service import chat_service_instance
# 内存数据库存储会话
# key: session_id (str), value: SessionRead object
sessions_db: Dict[str, SessionRead] = {}
# Import ChatService for title generation (consider refactoring later)
from app.services.chat_service import ChatService
import app.core.config as Config
chat_service_instance = ChatService(Config.GOOGLE_API_KEY)
class SessionService:
"""会话数据的 CRUD 及标题生成服务"""
"""会话数据的 CRUD 及标题生成服务 (数据库版)"""
async def create_session(self, session_data: SessionCreateRequest) -> SessionCreateResponse:
async def create_session(self, db: AsyncSession, session_data: SessionCreateRequest) -> SessionCreateResponse:
"""创建新会话并生成标题"""
assistant = assistant_service_instance.get_assistant(session_data.assistant_id)
# 检查助手是否存在
result = await db.execute(select(AssistantModel).filter(AssistantModel.id == session_data.assistant_id))
assistant = result.scalars().first()
if not assistant:
raise ValueError("指定的助手不存在")
new_id = f"session-{uuid.uuid4()}"
created_time = datetime.now(timezone.utc)
# --- 调用 LLM 生成标题 ---
try:
title_prompt = f"根据以下用户第一条消息为此对话生成一个简洁的标题不超过10个字:\n\n{session_data.first_message}"
generated_title = await chat_service_instance.generate_text(title_prompt)
except Exception as e:
print(f"生成会话标题时出错: {e}")
generated_title = f"关于 \"{session_data.first_message[:15]}...\"" # Fallback
# --- 生成结束 ---
# --- TODO: 调用 LLM 生成标题 ---
# title_prompt = f"根据以下用户第一条消息为此对话生成一个简洁的标题不超过10个字:\n\n{session_data.first_message}"
# generated_title = await chat_service_instance.generate_text(title_prompt) # 需要一个简单的文本生成方法
# 模拟标题生成
generated_title = f"关于 \"{session_data.first_message[:15]}...\""
print(f"为新会话 {new_id} 生成标题: {generated_title}")
# --- 模拟结束 ---
new_session = SessionRead(
id=new_id,
db_session = SessionModel(
title=generated_title,
assistant_id=session_data.assistant_id,
created_at=created_time.isoformat() # 存储 ISO 格式字符串
assistant_id=session_data.assistant_id
# ID and created_at have defaults
)
sessions_db[new_id] = new_session
print(f"会话已创建: {new_id}")
db.add(db_session)
await db.flush()
await db.refresh(db_session)
print(f"会话已创建 (DB): {db_session.id}")
return SessionCreateResponse(
id=new_session.id,
title=new_session.title,
assistant_id=new_session.assistant_id,
created_at=new_session.created_at
id=db_session.id,
title=db_session.title,
assistant_id=db_session.assistant_id,
created_at=db_session.created_at.isoformat() # Use datetime from DB model
)
def get_sessions_by_assistant(self, assistant_id: str) -> List[SessionRead]:
async def get_sessions_by_assistant(self, db: AsyncSession, assistant_id: str) -> List[SessionRead]:
"""获取指定助手的所有会话"""
return [s for s in sessions_db.values() if s.assistant_id == assistant_id]
stmt = select(SessionModel).filter(SessionModel.assistant_id == assistant_id).order_by(SessionModel.updated_at.desc()) # Order by update time
result = await db.execute(stmt)
sessions = result.scalars().all()
return [SessionRead.model_validate(s) for s in sessions]
def get_session(self, session_id: str) -> Optional[SessionRead]:
async def get_session(self, db: AsyncSession, session_id: str) -> Optional[SessionRead]:
"""获取单个会话"""
return sessions_db.get(session_id)
result = await db.execute(select(SessionModel).filter(SessionModel.id == session_id))
session = result.scalars().first()
return SessionRead.model_validate(session) if session else None
def delete_session(self, session_id: str) -> bool:
async def delete_session(self, db: AsyncSession, session_id: str) -> bool:
"""删除会话"""
if session_id in sessions_db:
del sessions_db[session_id]
print(f"会话已删除: {session_id}")
# TODO: 删除关联的消息
stmt = sqlalchemy_delete(SessionModel).where(SessionModel.id == session_id)
result = await db.execute(stmt)
if result.rowcount > 0:
await db.flush()
print(f"会话已删除 (DB): {session_id}")
# Deletion of messages handled by cascade
return True
return False
def delete_sessions_by_assistant(self, assistant_id: str) -> int:
"""删除指定助手的所有会话"""
ids_to_delete = [s.id for s in sessions_db.values() if s.assistant_id == assistant_id]
count = 0
for session_id in ids_to_delete:
if self.delete_session(session_id):
count += 1
print(f"删除了助手 {assistant_id}{count} 个会话")
return count
# 创建服务实例
session_service_instance = SessionService()

BIN
backend/cherryai.db Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@ -45,7 +45,7 @@ export default function RootLayout({
<html lang="zh-CN">
<body className={`${geistSans.variable} ${geistMono.variable} antialiased flex h-screen bg-gray-100 dark:bg-gray-900`}>
{/* 侧边栏导航 */}
<aside className="w-16 bg-white dark:bg-gray-800 p-3 shadow-md flex flex-col items-center mr-1"> {/* 调整内边距和对齐 */}
<aside className="w-16 bg-white dark:bg-gray-800 p-3 shadow-md flex flex-col items-center rounded-lg"> {/* 调整内边距和对齐 */}
{/* Logo */}
<div className="mb-6"> {/* 调整 Logo 边距 */}
<Link href="/" className="flex items-center justify-center text-3xl font-bold text-red-600 dark:text-red-500" title="CherryAI 主页">
@ -60,12 +60,11 @@ export default function RootLayout({
<Link
href={item.href}
className="relative flex items-center justify-center p-2 rounded-lg text-gray-600 dark:text-gray-400 hover:bg-red-100 dark:hover:bg-red-900/50 hover:text-red-700 dark:hover:text-red-400 transition-colors duration-200 group" // 居中图标
title={item.name} // 保留原生 title
>
<item.icon className="h-6 w-6 flex-shrink-0 text-gray-600 dark:text-gray-400" />
{/* Tooltip 文字标签 */}
<span
className="absolute left-full top-1/2 -translate-y-1/2 ml-3 px-2 py-1 bg-gray-900 dark:bg-gray-700 text-white text-xl rounded shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-200 delay-150 whitespace-nowrap pointer-events-none" // 使用 pointer-events-none 避免干扰悬浮
className="z-10 absolute left-full top-1/2 -translate-y-1/2 ml-3 px-2 py-1 bg-gray-900 dark:bg-gray-700 text-white text-xl rounded shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-200 delay-150 whitespace-nowrap pointer-events-none" // 使用 pointer-events-none 避免干扰悬浮
>
{item.name}
</span>

View File

@ -1,40 +1,28 @@
// File: frontend/lib/api.ts (更新)
// Description: 添加调用助手和会话管理 API 的函数
import { Assistant, AssistantCreateData, AssistantUpdateData } from "@/types/assistant";
import axios from "axios";
// --- Types (从后端模型同步或手动定义) ---
// 这些类型应该与后端 pydantic_models.py 中的 Read 模型匹配
export interface Assistant {
id: string;
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
// --- Types ---
export interface Session {
id: string;
title: string;
assistant_id: string;
created_at: string; // ISO date string
updated_at?: string | null; // Add updated_at
}
// 创建助手时发送的数据类型
export interface AssistantCreateData {
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
// Message type from backend
export interface Message {
id: string;
session_id: string;
sender: 'user' | 'ai'; // Or extend with 'system' if needed
text: string;
order: number;
created_at: string; // ISO date string
}
// 更新助手时发送的数据类型 (所有字段可选)
export type AssistantUpdateData = Partial<AssistantCreateData>;
// 聊天响应类型
export interface ChatApiResponse {
reply: string;
@ -169,4 +157,19 @@ export const deleteSession = async (sessionId: string): Promise<void> => {
//当前端发送 sessionId 为 'temp-new-chat' 的消息时,后端会自动创建。
//如果需要单独创建会话(例如,不发送消息就创建),则需要单独实现前端调用 POST /sessions/。
// TODO: 添加获取会话消息的 API 函数 (GET /sessions/{session_id}/messages)
// --- Message API (New) ---
/** 获取指定会话的消息列表 */
export const getMessagesBySession = async (sessionId: string, limit: number = 100, skip: number = 0): Promise<Message[]> => {
try {
const response = await apiClient.get<Message[]>(`/messages/session/${sessionId}`, {
params: { limit, skip }
});
return response.data;
} catch (error) {
// Handle 404 specifically if needed (session exists but no messages)
if (axios.isAxiosError(error) && error.response?.status === 404) {
return []; // Return empty list if session not found or no messages
}
throw new Error(handleApiError(error, 'getMessagesBySession'));
}
};

View File

@ -0,0 +1,23 @@
// --- Types (从后端模型同步或手动定义) ---
// 这些类型应该与后端 pydantic_models.py 中的 Read 模型匹配
export interface Assistant {
id: string;
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
// 创建助手时发送的数据类型
export interface AssistantCreateData {
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
// 更新助手时发送的数据类型 (所有字段可选)
export type AssistantUpdateData = Partial<AssistantCreateData>;