Compare commits

...

3 Commits

Author SHA1 Message Date
f0863914c2 添加数据管理 2025-04-30 04:39:36 +08:00
7df10e82be 加入会话管理和助手管理 2025-04-30 02:23:00 +08:00
194282e029 add shadcn ui 2025-04-30 00:52:29 +08:00
33 changed files with 3584 additions and 381 deletions

View File

@ -1,15 +1,12 @@
# File: backend/app/api/v1/api.py
# File: backend/app/api/v1/api.py (Update)
# Description: 聚合 v1 版本的所有 API 路由
from fastapi import APIRouter
from app.api.v1.endpoints import chat # 导入聊天路由
from app.api.v1.endpoints import chat, assistants, sessions, messages # Import messages router
# 创建 v1 版本的总路由
api_router = APIRouter()
# 将聊天路由包含到 v1 总路由中,并添加前缀
api_router.include_router(chat.router, prefix="/chat", tags=["chat"])
# --- 如果有其他路由,也在这里 include ---
# from app.api.v1.endpoints import workflow
# api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
api_router.include_router(chat.router, prefix="/chat", tags=["Chat"])
api_router.include_router(assistants.router, prefix="/assistants", tags=["Assistants"])
api_router.include_router(sessions.router, prefix="/sessions", tags=["Sessions"])
api_router.include_router(messages.router, prefix="/messages", tags=["Messages"]) # Add messages router

View File

@ -0,0 +1,68 @@
# File: backend/app/api/v1/endpoints/assistants.py (Update with DB session dependency)
# Description: 助手的 API 路由 (使用数据库会话)
from fastapi import APIRouter, HTTPException, Depends, status
from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session # Import DB session dependency
from app.models.pydantic_models import AssistantRead, AssistantCreate, AssistantUpdate
from app.services.assistant_service import AssistantService # Import the class
router = APIRouter()
# --- Dependency Injection for Service and DB Session ---
# Service instance can be created per request or globally
# For simplicity, let's create it here, but pass db session to methods
assistant_service = AssistantService()
@router.post("/", response_model=AssistantRead, status_code=status.HTTP_201_CREATED)
async def create_new_assistant(
assistant_data: AssistantCreate,
db: AsyncSession = Depends(get_db_session) # Inject DB session
):
return await assistant_service.create_assistant(db, assistant_data)
@router.get("/", response_model=List[AssistantRead])
async def read_all_assistants(
db: AsyncSession = Depends(get_db_session)
):
return await assistant_service.get_assistants(db)
@router.get("/{assistant_id}", response_model=AssistantRead)
async def read_assistant_by_id(
assistant_id: str,
db: AsyncSession = Depends(get_db_session)
):
assistant = await assistant_service.get_assistant(db, assistant_id)
if not assistant:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
return assistant
@router.put("/{assistant_id}", response_model=AssistantRead)
async def update_existing_assistant(
assistant_id: str,
assistant_data: AssistantUpdate,
db: AsyncSession = Depends(get_db_session)
):
updated_assistant = await assistant_service.update_assistant(db, assistant_id, assistant_data)
if not updated_assistant:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
return updated_assistant
@router.delete("/{assistant_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_existing_assistant(
assistant_id: str,
db: AsyncSession = Depends(get_db_session)
):
# Handle potential error from service if trying to delete default
try:
deleted = await assistant_service.delete_assistant(db, assistant_id)
if not deleted:
# Check if it exists to differentiate 404 from 403 (or handle in service)
assistant = await assistant_service.get_assistant(db, assistant_id)
if assistant and assistant_id == 'asst-default':
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="不允许删除默认助手")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的助手")
except Exception as e: # Catch other potential DB errors
print(f"删除助手时出错: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="删除助手失败")

View File

@ -1,41 +1,68 @@
# File: backend/app/api/v1/endpoints/chat.py (更新)
# Description: 聊天功能的 API 路由 (使用 ChatService)
# File: backend/app/api/v1/endpoints/chat.py (Update with DB session dependency)
# Description: 聊天功能的 API 路由 (使用数据库会话)
from fastapi import APIRouter, HTTPException, Depends
from app.models.pydantic_models import ChatRequest, ChatResponse
# 导入 ChatService 实例
from app.services.chat_service import chat_service_instance, ChatService
from fastapi import APIRouter, HTTPException, Depends, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session
from app.models.pydantic_models import ChatRequest, ChatResponse, SessionCreateRequest
from app.services.chat_service import ChatService # Import class
from app.services.session_service import SessionService # Import class
import app.core.config as Config # Import API Key for ChatService instantiation
router = APIRouter()
# --- (可选) 使用 FastAPI 的依赖注入来获取 ChatService 实例 ---
# 这样更符合 FastAPI 的风格,方便测试和替换实现
# async def get_chat_service() -> ChatService:
# return chat_service_instance
# --- Dependency Injection ---
# Instantiate services here or use a more sophisticated dependency injection system
chat_service = ChatService(default_api_key=Config.GOOGLE_API_KEY)
session_service = SessionService()
@router.post("/", response_model=ChatResponse)
async def handle_chat_message(
request: ChatRequest,
# chat_service: ChatService = Depends(get_chat_service) # 使用依赖注入
db: AsyncSession = Depends(get_db_session) # Inject DB session
):
"""
处理用户发送的聊天消息并使用 LangChain 获取 AI 回复
"""
user_message = request.message
# session_id = request.session_id # 如果 ChatRequest 中包含 session_id
print(f"接收到用户消息: {user_message}")
user_message = request.message
session_id = request.session_id
assistant_id = request.assistant_id
try:
# --- 调用 ChatService 获取 AI 回复 ---
# 使用全局实例 (简单方式)
ai_reply = await chat_service_instance.get_ai_reply(user_message)
# 或者使用依赖注入获取的实例
# ai_reply = await chat_service.get_ai_reply(user_message, session_id)
print(f"接收到消息: User='{user_message}', Session='{session_id}', Assistant='{assistant_id}'")
print(f"发送 AI 回复: {ai_reply}")
return ChatResponse(reply=ai_reply)
response_session_id = None
response_session_title = None
except Exception as e:
# 如果 ChatService 抛出异常,捕获并返回 HTTP 500 错误
print(f"处理聊天消息时发生错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
if session_id == 'temp-new-chat':
print("检测到临时新会话,正在创建...")
try:
create_req = SessionCreateRequest(assistant_id=assistant_id, first_message=user_message)
# Pass db session to the service method
created_session = await session_service.create_session(db, create_req)
session_id = created_session.id
response_session_id = created_session.id
response_session_title = created_session.title
print(f"新会话已创建: ID={session_id}, Title='{created_session.title}'")
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
print(f"创建会话时出错: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败")
try:
# Pass db session to the service method
ai_reply = await chat_service.get_ai_reply(
db=db,
user_message=user_message,
session_id=session_id,
assistant_id=assistant_id
)
print(f"发送 AI 回复: '{ai_reply}'")
return ChatResponse(
reply=ai_reply,
session_id=response_session_id,
session_title=response_session_title
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
print(f"处理聊天消息时发生错误: {e}")
# The get_db_session dependency will handle rollback
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))

View File

@ -0,0 +1,33 @@
# File: backend/app/api/v1/endpoints/messages.py (New)
# Description: API endpoint for fetching messages
from fastapi import APIRouter, Depends, HTTPException, status, Query
from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session
from app.models.pydantic_models import MessageRead
from app.db.models import MessageModel # Import DB model
from sqlalchemy.future import select
router = APIRouter()
@router.get("/session/{session_id}", response_model=List[MessageRead])
async def read_messages_for_session(
session_id: str,
db: AsyncSession = Depends(get_db_session),
skip: int = Query(0, ge=0), # Offset for pagination
limit: int = Query(100, ge=1, le=500) # Limit number of messages
):
"""获取指定会话的消息列表 (按时间顺序)"""
# TODO: Add check if session exists
stmt = (
select(MessageModel)
.filter(MessageModel.session_id == session_id)
.order_by(MessageModel.order.asc()) # Fetch in chronological order
.offset(skip)
.limit(limit)
)
result = await db.execute(stmt)
messages = result.scalars().all()
# Validate using Pydantic model before returning
return [MessageRead.model_validate(msg) for msg in messages]

View File

@ -0,0 +1,52 @@
# File: backend/app/api/v1/endpoints/sessions.py (Update with DB session dependency)
# Description: 会话管理的 API 路由 (使用数据库会话)
from fastapi import APIRouter, HTTPException, Depends, status
from typing import List
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.database import get_db_session
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
from app.services.session_service import SessionService # Import the class
router = APIRouter()
session_service = SessionService() # Create instance
@router.post("/", response_model=SessionCreateResponse, status_code=status.HTTP_201_CREATED)
async def create_new_session(
session_data: SessionCreateRequest,
db: AsyncSession = Depends(get_db_session) # Inject DB session
):
try:
return await session_service.create_session(db, session_data)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except Exception as e:
print(f"创建会话时出错: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="创建会话失败")
@router.get("/assistant/{assistant_id}", response_model=List[SessionRead])
async def read_sessions_for_assistant(
assistant_id: str,
db: AsyncSession = Depends(get_db_session)
):
# Consider adding check if assistant exists first
return await session_service.get_sessions_by_assistant(db, assistant_id)
@router.get("/{session_id}", response_model=SessionRead)
async def read_session_by_id(
session_id: str,
db: AsyncSession = Depends(get_db_session)
):
session = await session_service.get_session(db, session_id)
if not session:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话")
return session
@router.delete("/{session_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_existing_session(
session_id: str,
db: AsyncSession = Depends(get_db_session)
):
deleted = await session_service.delete_session(db, session_id)
if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="找不到指定的会话")

View File

@ -12,4 +12,6 @@ load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") # 如果使用 Google
# 可以在这里添加其他配置项
# Define the database URL (SQLite in this case)
# DATABASE_URL = "sqlite+aiosqlite:///./cherryai.db" # Use async driver
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./cherryai.db")

View File

@ -0,0 +1,43 @@
# File: backend/app/db/database.py (New - Database setup)
# Description: SQLAlchemy 数据库引擎和会话设置
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from app.core.config import DATABASE_URL
# 创建异步数据库引擎
# connect_args={"check_same_thread": False} is needed only for SQLite.
# It's not needed for other databases.
engine = create_async_engine(DATABASE_URL, echo=True, connect_args={"check_same_thread": False})
# 创建异步会话工厂
# expire_on_commit=False prevents attributes from expiring after commit.
AsyncSessionFactory = sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autocommit=False,
autoflush=False,
)
# 创建数据模型的基础类
Base = declarative_base()
# --- Dependency to get DB session ---
async def get_db_session() -> AsyncSession:
"""FastAPI dependency to get an async database session."""
async with AsyncSessionFactory() as session:
try:
yield session
await session.commit() # Commit transaction if successful
except Exception:
await session.rollback() # Rollback on error
raise
finally:
await session.close()
# --- Function to create tables (call this on startup) ---
async def create_db_and_tables():
async with engine.begin() as conn:
# await conn.run_sync(Base.metadata.drop_all) # Use drop_all carefully in dev
await conn.run_sync(Base.metadata.create_all)

52
backend/app/db/models.py Normal file
View File

@ -0,0 +1,52 @@
# File: backend/app/db/models.py (New - Database models)
# Description: SQLAlchemy ORM 模型定义
from sqlalchemy import Column, String, Float, ForeignKey, Text, DateTime, Integer
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func # For default timestamps
from app.db.database import Base
import uuid
from datetime import datetime, timezone
def generate_uuid():
return str(uuid.uuid4())
class AssistantModel(Base):
__tablename__ = "assistants"
id = Column(String, primary_key=True, default=generate_uuid)
name = Column(String(50), nullable=False, index=True)
description = Column(String(200), nullable=True)
avatar = Column(String(5), nullable=True)
system_prompt = Column(Text, nullable=False)
model = Column(String, nullable=False)
temperature = Column(Float, nullable=False, default=0.7)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
sessions = relationship("SessionModel", back_populates="assistant", cascade="all, delete-orphan")
class SessionModel(Base):
__tablename__ = "sessions"
id = Column(String, primary_key=True, default=generate_uuid)
title = Column(String(100), nullable=False, default="New Chat")
assistant_id = Column(String, ForeignKey("assistants.id"), nullable=False, index=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), index=True) # Index for sorting
assistant = relationship("AssistantModel", back_populates="sessions")
messages = relationship("MessageModel", back_populates="session", cascade="all, delete-orphan", order_by="MessageModel.created_at") # Order messages by time
class MessageModel(Base):
__tablename__ = "messages"
id = Column(String, primary_key=True, default=generate_uuid)
session_id = Column(String, ForeignKey("sessions.id"), nullable=False, index=True)
sender = Column(String(10), nullable=False) # 'user' or 'ai' or 'system'
text = Column(Text, nullable=False)
order = Column(Integer, nullable=False) # Explicit order within session
created_at = Column(DateTime(timezone=True), server_default=func.now())
session = relationship("SessionModel", back_populates="messages")

View File

@ -1,20 +1,31 @@
# File: backend/app/main.py (确认 load_dotenv 调用位置)
# Description: FastAPI 应用入口
# File: backend/app/main.py (Update - Add startup event)
# Description: FastAPI 应用入口 (添加数据库初始化)
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from app.api.v1.api import api_router as api_router_v1
# 确保在创建 FastAPI 实例之前加载环境变量
from app.core.config import OPENAI_API_KEY # 导入会触发 load_dotenv
import app.core.config # Ensure config is loaded
from app.db.database import create_db_and_tables # Import table creation function
from contextlib import asynccontextmanager
# 创建 FastAPI 应用实例
app = FastAPI(title="CherryAI Backend", version="0.1.0")
# --- Lifespan context manager for startup/shutdown events ---
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup actions
print("应用程序启动中...")
await create_db_and_tables() # Create database tables on startup
print("数据库表已检查/创建。")
# You can add the default assistant creation here if needed,
# but doing it in the service/model definition might be simpler for defaults.
yield
# Shutdown actions
print("应用程序关闭中...")
# --- 配置 CORS ---
origins = [
"http://localhost:3000",
"http://127.0.0.1:3000",
]
# Create FastAPI app with lifespan context manager
app = FastAPI(title="CherryAI Backend", version="0.1.0", lifespan=lifespan)
# --- CORS Middleware ---
origins = [ "http://localhost:3000", "http://127.0.0.1:3000" ]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
@ -23,10 +34,10 @@ app.add_middleware(
allow_headers=["*"],
)
# --- 挂载 API 路由 ---
# --- API Routers ---
app.include_router(api_router_v1, prefix="/api/v1")
# --- 根路径 ---
# --- Root Endpoint ---
@app.get("/", tags=["Root"])
async def read_root():
return {"message": "欢迎来到 CherryAI 后端!"}
return {"message": "欢迎来到 CherryAI 后端!"}

View File

@ -1,16 +1,93 @@
# File: backend/app/models/pydantic_models.py
# File: backend/app/models/pydantic_models.py (Update Read models, add Message models)
# Description: Pydantic 模型定义 API 数据结构
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import Optional, List
import uuid
from datetime import datetime # Use datetime directly
# --- Assistant Models ---
class AssistantBase(BaseModel):
"""助手的基础模型,包含通用字段"""
name: str = Field(..., min_length=1, max_length=50, description="助手名称")
description: Optional[str] = Field(None, max_length=200, description="助手描述")
avatar: Optional[str] = Field(None, max_length=5, description="头像 Emoji 或字符")
system_prompt: str = Field(..., min_length=1, max_length=4000, description="系统提示")
model: str = Field(..., description="使用的 LLM 模型")
temperature: float = Field(0.7, ge=0.0, le=1.0, description="温度参数 (0.0-1.0)")
# 可以添加 top_p, max_tokens 等
class AssistantCreate(AssistantBase):
"""创建助手时使用的模型 (不需要 ID)"""
pass
class AssistantUpdate(BaseModel):
"""更新助手时使用的模型 (所有字段可选)"""
name: Optional[str] = Field(None, min_length=1, max_length=50)
description: Optional[str] = Field(None, max_length=200)
avatar: Optional[str] = Field(None, max_length=5)
system_prompt: Optional[str] = Field(None, min_length=1, max_length=4000)
model: Optional[str] = None
temperature: Optional[float] = Field(None, ge=0.0, le=1.0)
class AssistantRead(AssistantBase):
"""读取助手信息时返回的模型 (包含 ID)"""
id: str = Field(..., description="助手唯一 ID")
created_at: datetime # Add timestamps
updated_at: Optional[datetime] = None
class Config:
from_attributes = True # Pydantic v2: orm_mode = True
# --- Chat Models (更新) ---
class ChatRequest(BaseModel):
"""聊天请求模型"""
message: str
# 可以添加更多字段,如 user_id, session_id 等
"""聊天请求模型 (添加 sessionId 和 assistantId)"""
message: str
session_id: str = Field(..., description="当前会话 ID (可以是 'temp-new-chat')")
assistant_id: str = Field(..., description="当前使用的助手 ID")
class ChatResponse(BaseModel):
"""聊天响应模型"""
reply: str
# 可以添加更多字段,如 message_id, status 等
"""聊天响应模型"""
reply: str
session_id: Optional[str] = None # (可选) 如果创建了新会话,返回新 ID
session_title: Optional[str] = None # (可选) 如果创建了新会话,返回新标题
# --- 你可以在这里添加其他功能的模型 ---
# --- Session Models ---
class SessionCreateRequest(BaseModel):
"""创建会话请求模型"""
assistant_id: str
first_message: str # 用户的第一条消息,用于生成标题
class SessionCreateResponse(BaseModel):
"""创建会话响应模型"""
id: str
title: str
assistant_id: str
created_at: str # 返回 ISO 格式时间字符串
class SessionRead(BaseModel):
"""读取会话信息模型"""
id: str
title: str
assistant_id: str
created_at: datetime # Use datetime
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
# --- Message Models (New) ---
class MessageBase(BaseModel):
sender: str # 'user' or 'ai'
text: str
class MessageRead(MessageBase):
id: str
session_id: str
order: int
created_at: datetime
class Config:
from_attributes = True

View File

@ -0,0 +1,76 @@
# File: backend/app/services/assistant_service.py (Update with DB)
# Description: 管理助手数据的服务 (使用 SQLAlchemy)
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update as sqlalchemy_update, delete as sqlalchemy_delete
from app.db.models import AssistantModel
from app.models.pydantic_models import AssistantCreate, AssistantUpdate, AssistantRead
class AssistantService:
"""助手数据的 CRUD 服务 (数据库版)"""
async def get_assistants(self, db: AsyncSession) -> List[AssistantRead]:
"""获取所有助手"""
result = await db.execute(select(AssistantModel).order_by(AssistantModel.name))
assistants = result.scalars().all()
return [AssistantRead.model_validate(a) for a in assistants] # Use model_validate in Pydantic v2
async def get_assistant(self, db: AsyncSession, assistant_id: str) -> Optional[AssistantRead]:
"""根据 ID 获取单个助手"""
result = await db.execute(select(AssistantModel).filter(AssistantModel.id == assistant_id))
assistant = result.scalars().first()
return AssistantRead.model_validate(assistant) if assistant else None
async def create_assistant(self, db: AsyncSession, assistant_data: AssistantCreate) -> AssistantRead:
"""创建新助手"""
# 使用 Pydantic 模型创建 DB 模型实例
db_assistant = AssistantModel(**assistant_data.model_dump())
# ID will be generated by default in the model
db.add(db_assistant)
await db.flush() # Flush to get the generated ID and defaults
await db.refresh(db_assistant) # Refresh to load all attributes
print(f"助手已创建 (DB): {db_assistant.id} - {db_assistant.name}")
return AssistantRead.model_validate(db_assistant)
async def update_assistant(self, db: AsyncSession, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]:
"""更新现有助手"""
update_values = assistant_data.model_dump(exclude_unset=True)
if not update_values:
# If nothing to update, just fetch and return the existing one
return await self.get_assistant(db, assistant_id)
# Execute update statement
stmt = (
sqlalchemy_update(AssistantModel)
.where(AssistantModel.id == assistant_id)
.values(**update_values)
.returning(AssistantModel) # Return the updated row
)
result = await db.execute(stmt)
updated_assistant = result.scalars().first()
if updated_assistant:
await db.flush()
await db.refresh(updated_assistant)
print(f"助手已更新 (DB): {assistant_id}")
return AssistantRead.model_validate(updated_assistant)
return None # Assistant not found
async def delete_assistant(self, db: AsyncSession, assistant_id: str) -> bool:
"""删除助手"""
# Prevent deleting default assistant
if assistant_id == 'asst-default': # Assuming 'asst-default' is a known ID
print("尝试删除默认助手 - 操作被阻止")
return False
stmt = sqlalchemy_delete(AssistantModel).where(AssistantModel.id == assistant_id)
result = await db.execute(stmt)
if result.rowcount > 0:
await db.flush()
print(f"助手已删除 (DB): {assistant_id}")
# Deletion of sessions/messages handled by cascade="all, delete-orphan"
return True
return False

View File

@ -1,92 +1,137 @@
# File: backend/app/services/chat_service.py (新建)
# Description: 封装 LangChain 聊天逻辑
# File: backend/app/services/chat_service.py (Update with DB for history)
# Description: 封装 LangChain 聊天逻辑 (使用数据库存储和检索消息)
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI # 如果使用 Google
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage
# --- 可选:添加内存管理 ---
# 简单的内存实现 (可以替换为更复杂的 LangChain Memory 类)
chat_history = {} # 使用字典存储不同会话的内存,需要 session_id
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from typing import Dict, List, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from app.db.models import MessageModel, AssistantModel # Import DB models
from app.services.assistant_service import AssistantService # Use class directly
from app.models.pydantic_models import AssistantRead
class ChatService:
"""处理 AI 聊天交互的服务"""
"""处理 AI 聊天交互的服务 (使用数据库历史)"""
def __init__(self, api_key: str):
"""
初始化 ChatService
Args:
api_key (str): 用于 LLM API 密钥
"""
# --- 选择并初始化 LLM ---
# 使用 OpenAI GPT-3.5 Turbo (推荐) 或 GPT-4
# self.llm = ChatOpenAI(model="gpt-3.5-turbo", api_key=api_key, temperature=0.7)
# self.llm = ChatOpenAI(model="gpt-4", api_key=api_key, temperature=0.7)
def __init__(self, default_api_key: str):
self.default_api_key = default_api_key
self.assistant_service = AssistantService() # Instantiate assistant service
# --- 如果使用 Google Gemini ---
self.llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=api_key, convert_system_message_to_human=True)
def _get_llm(self, assistant: AssistantRead) -> ChatOpenAI:
# ... (remains the same) ...
if assistant.model.startswith("gpt"):
return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)
elif assistant.model.startswith("gemini"):
return ChatGoogleGenerativeAI(
model=assistant.model,
api_key=self.default_api_key, # 或从助手配置中读取特定 key
temperature=assistant.temperature
)
else:
print(f"警告: 模型 {assistant.model} 未明确支持,尝试使用 ChatOpenAI")
return ChatOpenAI(model=assistant.model, api_key=self.default_api_key, temperature=assistant.temperature)
# --- 定义 Prompt 模板 ---
# 包含系统消息、历史记录占位符和当前用户输入
self.prompt = ChatPromptTemplate.from_messages([
("system", "你是一个乐于助人的 AI 助手,请用简洁明了的语言回答问题。你的名字叫 CherryAI。"),
MessagesPlaceholder(variable_name="chat_history"), # 用于插入历史消息
("human", "{input}"), # 用户当前输入
async def _get_chat_history(self, db: AsyncSession, session_id: str, limit: int = 10) -> Tuple[List[BaseMessage], int]:
"""从数据库加载指定会话的历史消息 (按 order 排序)"""
stmt = (
select(MessageModel)
.filter(MessageModel.session_id == session_id)
.order_by(MessageModel.order.desc()) # Get latest first
.limit(limit)
)
result = await db.execute(stmt)
db_messages = result.scalars().all()
# Convert to LangChain messages (in correct order: oldest to newest)
history: List[BaseMessage] = []
max_order = 0
for msg in reversed(db_messages): # Reverse to get chronological order
if msg.sender == 'user':
history.append(HumanMessage(content=msg.text))
elif msg.sender == 'ai':
history.append(AIMessage(content=msg.text))
# Add handling for 'system' if needed
max_order = max(max_order, msg.order) # Keep track of the latest order number
return history, max_order
async def _save_message(self, db: AsyncSession, session_id: str, sender: str, text: str, order: int):
"""将消息保存到数据库"""
db_message = MessageModel(
session_id=session_id,
sender=sender,
text=text,
order=order
)
db.add(db_message)
await db.flush() # Ensure it's added before potential commit
print(f"消息已保存 (DB): Session={session_id}, Order={order}, Sender={sender}")
async def get_ai_reply(self, db: AsyncSession, user_message: str, session_id: str, assistant_id: str) -> str:
"""获取 AI 回复,并保存用户消息和 AI 回复到数据库"""
# 1. 获取助手配置
assistant = await self.assistant_service.get_assistant(db, assistant_id)
if not assistant:
raise ValueError(f"找不到助手 ID: {assistant_id}")
# 2. 获取历史记录和下一个序号
current_chat_history, last_order = await self._get_chat_history(db, session_id)
user_message_order = last_order + 1
ai_message_order = last_order + 2
# 3. 构建 Prompt
prompt = ChatPromptTemplate.from_messages([
SystemMessage(content=assistant.system_prompt),
MessagesPlaceholder(variable_name="chat_history"),
HumanMessage(content=user_message),
])
# --- 定义输出解析器 ---
# 将 LLM 的输出解析为字符串
self.output_parser = StrOutputParser()
# 4. 获取 LLM
llm = self._get_llm(assistant)
output_parser = StrOutputParser()
chain = prompt | llm | output_parser
# --- 构建 LangChain Expression Language (LCEL) 链 ---
self.chain = self.prompt | self.llm | self.output_parser
async def get_ai_reply(self, user_message: str, session_id: str = "default_session") -> str:
"""
获取 AI 对用户消息的回复 (异步)
Args:
user_message (str): 用户发送的消息
session_id (str): (可选) 用于区分不同对话的会话 ID以支持内存
Returns:
str: AI 的回复文本
Raises:
Exception: 如果调用 AI 服务时发生错误
"""
try:
# --- 获取当前会话的历史记录 (如果需要内存) ---
current_chat_history = chat_history.get(session_id, [])
# --- Save user message BEFORE calling LLM ---
await self._save_message(db, session_id, 'user', user_message, user_message_order)
# --- 使用 ainvoke 进行异步调用 ---
ai_response = await self.chain.ainvoke({
# 5. 调用链获取回复
ai_response = await chain.ainvoke({
"input": user_message,
"chat_history": current_chat_history, # 传入历史记录
"chat_history": current_chat_history, # Pass history fetched from DB
})
# --- 更新会话历史记录 (如果需要内存) ---
# 只保留最近 N 轮对话,防止历史过长
max_history_length = 10 # 保留最近 5 轮对话 (10条消息)
current_chat_history.append(HumanMessage(content=user_message))
current_chat_history.append(AIMessage(content=ai_response))
# 如果历史记录超过长度,移除最早的消息
if len(current_chat_history) > max_history_length:
chat_history[session_id] = current_chat_history[-max_history_length:]
else:
chat_history[session_id] = current_chat_history
# --- Save AI response AFTER getting it ---
await self._save_message(db, session_id, 'ai', ai_response, ai_message_order)
# Note: We don't need to manage history in memory anymore (chat_history_db removed)
return ai_response
except Exception as e:
print(f"调用 LangChain 时出错: {e}")
# 可以进行更细致的错误处理,例如区分 API 错误和内部错误
# Consider rolling back the user message save if LLM call fails,
# although often it's better to keep the user message.
# await db.rollback() # Handled by get_db_session dependency on error
print(f"调用 LangChain 时出错 (助手: {assistant_id}, 会话: {session_id}): {e}")
raise Exception(f"AI 服务暂时不可用: {e}")
async def generate_text(self, prompt_text: str, model_name: str = "gemini-2.0-flash", temperature: float = 0.5) -> str:
# ... (remains the same) ...
try:
temp_llm = ChatGoogleGenerativeAI(
model=model_name,
api_key=self.default_api_key, # 或从助手配置中读取特定 key
temperature=temperature
)
response = await temp_llm.ainvoke(prompt_text)
return response.content
except Exception as e:
print(f"生成文本时出错: {e}")
return "无法生成标题"
# --- 在文件末尾,创建 ChatService 的实例 ---
# 从配置中获取 API Key
from app.core.config import GOOGLE_API_KEY
if not GOOGLE_API_KEY:
raise ValueError("请在 .env 文件中设置 GOOGLE_API_KEY")
chat_service_instance = ChatService(api_key=GOOGLE_API_KEY)
# ChatService instance is now created where needed or injected, no global instance here.

View File

@ -0,0 +1,74 @@
# File: backend/app/services/session_service.py (Update with DB)
# Description: 管理会话数据的服务 (使用 SQLAlchemy)
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import delete as sqlalchemy_delete
from app.db.models import SessionModel, AssistantModel # Import DB models
from app.models.pydantic_models import SessionRead, SessionCreateRequest, SessionCreateResponse
from datetime import datetime, timezone
# Import ChatService for title generation (consider refactoring later)
from app.services.chat_service import ChatService
import app.core.config as Config
chat_service_instance = ChatService(Config.GOOGLE_API_KEY)
class SessionService:
"""会话数据的 CRUD 及标题生成服务 (数据库版)"""
async def create_session(self, db: AsyncSession, session_data: SessionCreateRequest) -> SessionCreateResponse:
"""创建新会话并生成标题"""
# 检查助手是否存在
result = await db.execute(select(AssistantModel).filter(AssistantModel.id == session_data.assistant_id))
assistant = result.scalars().first()
if not assistant:
raise ValueError("指定的助手不存在")
# --- 调用 LLM 生成标题 ---
try:
title_prompt = f"根据以下用户第一条消息为此对话生成一个简洁的标题不超过10个字:\n\n{session_data.first_message}"
generated_title = await chat_service_instance.generate_text(title_prompt)
except Exception as e:
print(f"生成会话标题时出错: {e}")
generated_title = f"关于 \"{session_data.first_message[:15]}...\"" # Fallback
# --- 生成结束 ---
db_session = SessionModel(
title=generated_title,
assistant_id=session_data.assistant_id
# ID and created_at have defaults
)
db.add(db_session)
await db.flush()
await db.refresh(db_session)
print(f"会话已创建 (DB): {db_session.id}")
return SessionCreateResponse(
id=db_session.id,
title=db_session.title,
assistant_id=db_session.assistant_id,
created_at=db_session.created_at.isoformat() # Use datetime from DB model
)
async def get_sessions_by_assistant(self, db: AsyncSession, assistant_id: str) -> List[SessionRead]:
"""获取指定助手的所有会话"""
stmt = select(SessionModel).filter(SessionModel.assistant_id == assistant_id).order_by(SessionModel.updated_at.desc()) # Order by update time
result = await db.execute(stmt)
sessions = result.scalars().all()
return [SessionRead.model_validate(s) for s in sessions]
async def get_session(self, db: AsyncSession, session_id: str) -> Optional[SessionRead]:
"""获取单个会话"""
result = await db.execute(select(SessionModel).filter(SessionModel.id == session_id))
session = result.scalars().first()
return SessionRead.model_validate(session) if session else None
async def delete_session(self, db: AsyncSession, session_id: str) -> bool:
"""删除会话"""
stmt = sqlalchemy_delete(SessionModel).where(SessionModel.id == session_id)
result = await db.execute(stmt)
if result.rowcount > 0:
await db.flush()
print(f"会话已删除 (DB): {session_id}")
# Deletion of messages handled by cascade
return True
return False

BIN
backend/cherryai.db Normal file

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@ -45,7 +45,7 @@ export default function RootLayout({
<html lang="zh-CN">
<body className={`${geistSans.variable} ${geistMono.variable} antialiased flex h-screen bg-gray-100 dark:bg-gray-900`}>
{/* 侧边栏导航 */}
<aside className="w-16 bg-white dark:bg-gray-800 p-3 shadow-md flex flex-col items-center mr-1"> {/* 调整内边距和对齐 */}
<aside className="w-16 bg-white dark:bg-gray-800 p-3 shadow-md flex flex-col items-center rounded-lg"> {/* 调整内边距和对齐 */}
{/* Logo */}
<div className="mb-6"> {/* 调整 Logo 边距 */}
<Link href="/" className="flex items-center justify-center text-3xl font-bold text-red-600 dark:text-red-500" title="CherryAI 主页">
@ -60,12 +60,11 @@ export default function RootLayout({
<Link
href={item.href}
className="relative flex items-center justify-center p-2 rounded-lg text-gray-600 dark:text-gray-400 hover:bg-red-100 dark:hover:bg-red-900/50 hover:text-red-700 dark:hover:text-red-400 transition-colors duration-200 group" // 居中图标
title={item.name} // 保留原生 title
>
<item.icon className="h-6 w-6 flex-shrink-0 text-gray-600 dark:text-gray-400" />
{/* Tooltip 文字标签 */}
<span
className="absolute left-full top-1/2 -translate-y-1/2 ml-3 px-2 py-1 bg-gray-900 dark:bg-gray-700 text-white text-xs rounded shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-200 delay-150 whitespace-nowrap pointer-events-none" // 使用 pointer-events-none 避免干扰悬浮
className="z-10 absolute left-full top-1/2 -translate-y-1/2 ml-3 px-2 py-1 bg-gray-900 dark:bg-gray-700 text-white text-xl rounded shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-200 delay-150 whitespace-nowrap pointer-events-none" // 使用 pointer-events-none 避免干扰悬浮
>
{item.name}
</span>

21
frontend/components.json Normal file
View File

@ -0,0 +1,21 @@
{
"$schema": "https://ui.shadcn.com/schema.json",
"style": "new-york",
"rsc": true,
"tsx": true,
"tailwind": {
"config": "tailwind.config.ts",
"css": "styles/globals.css",
"baseColor": "zinc",
"cssVariables": true,
"prefix": ""
},
"aliases": {
"components": "@/components",
"utils": "@/lib/utils",
"ui": "@/components/ui",
"lib": "@/lib",
"hooks": "@/hooks"
},
"iconLibrary": "lucide"
}

View File

@ -0,0 +1,59 @@
import * as React from "react"
import { Slot } from "@radix-ui/react-slot"
import { cva, type VariantProps } from "class-variance-authority"
import { cn } from "@/lib/utils"
const buttonVariants = cva(
"inline-flex items-center justify-center gap-2 whitespace-nowrap rounded-md text-sm font-medium transition-all disabled:pointer-events-none disabled:opacity-50 [&_svg]:pointer-events-none [&_svg:not([class*='size-'])]:size-4 shrink-0 [&_svg]:shrink-0 outline-none focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px] aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
{
variants: {
variant: {
default:
"bg-primary text-primary-foreground shadow-xs hover:bg-primary/90",
destructive:
"bg-destructive text-white shadow-xs hover:bg-destructive/90 focus-visible:ring-destructive/20 dark:focus-visible:ring-destructive/40 dark:bg-destructive/60",
outline:
"border bg-background shadow-xs hover:bg-accent hover:text-accent-foreground dark:bg-input/30 dark:border-input dark:hover:bg-input/50",
secondary:
"bg-secondary text-secondary-foreground shadow-xs hover:bg-secondary/80",
ghost:
"hover:bg-accent hover:text-accent-foreground dark:hover:bg-accent/50",
link: "text-primary underline-offset-4 hover:underline",
},
size: {
default: "h-9 px-4 py-2 has-[>svg]:px-3",
sm: "h-8 rounded-md gap-1.5 px-3 has-[>svg]:px-2.5",
lg: "h-10 rounded-md px-6 has-[>svg]:px-4",
icon: "size-9",
},
},
defaultVariants: {
variant: "default",
size: "default",
},
}
)
function Button({
className,
variant,
size,
asChild = false,
...props
}: React.ComponentProps<"button"> &
VariantProps<typeof buttonVariants> & {
asChild?: boolean
}) {
const Comp = asChild ? Slot : "button"
return (
<Comp
data-slot="button"
className={cn(buttonVariants({ variant, size, className }))}
{...props}
/>
)
}
export { Button, buttonVariants }

View File

@ -0,0 +1,135 @@
"use client"
import * as React from "react"
import * as DialogPrimitive from "@radix-ui/react-dialog"
import { XIcon } from "lucide-react"
import { cn } from "@/lib/utils"
function Dialog({
...props
}: React.ComponentProps<typeof DialogPrimitive.Root>) {
return <DialogPrimitive.Root data-slot="dialog" {...props} />
}
function DialogTrigger({
...props
}: React.ComponentProps<typeof DialogPrimitive.Trigger>) {
return <DialogPrimitive.Trigger data-slot="dialog-trigger" {...props} />
}
function DialogPortal({
...props
}: React.ComponentProps<typeof DialogPrimitive.Portal>) {
return <DialogPrimitive.Portal data-slot="dialog-portal" {...props} />
}
function DialogClose({
...props
}: React.ComponentProps<typeof DialogPrimitive.Close>) {
return <DialogPrimitive.Close data-slot="dialog-close" {...props} />
}
function DialogOverlay({
className,
...props
}: React.ComponentProps<typeof DialogPrimitive.Overlay>) {
return (
<DialogPrimitive.Overlay
data-slot="dialog-overlay"
className={cn(
"data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 fixed inset-0 z-50 bg-black/50",
className
)}
{...props}
/>
)
}
function DialogContent({
className,
children,
...props
}: React.ComponentProps<typeof DialogPrimitive.Content>) {
return (
<DialogPortal data-slot="dialog-portal">
<DialogOverlay />
<DialogPrimitive.Content
data-slot="dialog-content"
className={cn(
"bg-background data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 fixed top-[50%] left-[50%] z-50 grid w-full max-w-[calc(100%-2rem)] translate-x-[-50%] translate-y-[-50%] gap-4 rounded-lg border p-6 shadow-lg duration-200 sm:max-w-lg",
className
)}
{...props}
>
{children}
<DialogPrimitive.Close className="ring-offset-background focus:ring-ring data-[state=open]:bg-accent data-[state=open]:text-muted-foreground absolute top-4 right-4 rounded-xs opacity-70 transition-opacity hover:opacity-100 focus:ring-2 focus:ring-offset-2 focus:outline-hidden disabled:pointer-events-none [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4">
<XIcon />
<span className="sr-only">Close</span>
</DialogPrimitive.Close>
</DialogPrimitive.Content>
</DialogPortal>
)
}
function DialogHeader({ className, ...props }: React.ComponentProps<"div">) {
return (
<div
data-slot="dialog-header"
className={cn("flex flex-col gap-2 text-center sm:text-left", className)}
{...props}
/>
)
}
function DialogFooter({ className, ...props }: React.ComponentProps<"div">) {
return (
<div
data-slot="dialog-footer"
className={cn(
"flex flex-col-reverse gap-2 sm:flex-row sm:justify-end",
className
)}
{...props}
/>
)
}
function DialogTitle({
className,
...props
}: React.ComponentProps<typeof DialogPrimitive.Title>) {
return (
<DialogPrimitive.Title
data-slot="dialog-title"
className={cn("text-lg leading-none font-semibold", className)}
{...props}
/>
)
}
function DialogDescription({
className,
...props
}: React.ComponentProps<typeof DialogPrimitive.Description>) {
return (
<DialogPrimitive.Description
data-slot="dialog-description"
className={cn("text-muted-foreground text-sm", className)}
{...props}
/>
)
}
export {
Dialog,
DialogClose,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogOverlay,
DialogPortal,
DialogTitle,
DialogTrigger,
}

View File

@ -0,0 +1,167 @@
"use client"
import * as React from "react"
import * as LabelPrimitive from "@radix-ui/react-label"
import { Slot } from "@radix-ui/react-slot"
import {
Controller,
FormProvider,
useFormContext,
useFormState,
type ControllerProps,
type FieldPath,
type FieldValues,
} from "react-hook-form"
import { cn } from "@/lib/utils"
import { Label } from "@/components/ui/label"
const Form = FormProvider
type FormFieldContextValue<
TFieldValues extends FieldValues = FieldValues,
TName extends FieldPath<TFieldValues> = FieldPath<TFieldValues>,
> = {
name: TName
}
const FormFieldContext = React.createContext<FormFieldContextValue>(
{} as FormFieldContextValue
)
const FormField = <
TFieldValues extends FieldValues = FieldValues,
TName extends FieldPath<TFieldValues> = FieldPath<TFieldValues>,
>({
...props
}: ControllerProps<TFieldValues, TName>) => {
return (
<FormFieldContext.Provider value={{ name: props.name }}>
<Controller {...props} />
</FormFieldContext.Provider>
)
}
const useFormField = () => {
const fieldContext = React.useContext(FormFieldContext)
const itemContext = React.useContext(FormItemContext)
const { getFieldState } = useFormContext()
const formState = useFormState({ name: fieldContext.name })
const fieldState = getFieldState(fieldContext.name, formState)
if (!fieldContext) {
throw new Error("useFormField should be used within <FormField>")
}
const { id } = itemContext
return {
id,
name: fieldContext.name,
formItemId: `${id}-form-item`,
formDescriptionId: `${id}-form-item-description`,
formMessageId: `${id}-form-item-message`,
...fieldState,
}
}
type FormItemContextValue = {
id: string
}
const FormItemContext = React.createContext<FormItemContextValue>(
{} as FormItemContextValue
)
function FormItem({ className, ...props }: React.ComponentProps<"div">) {
const id = React.useId()
return (
<FormItemContext.Provider value={{ id }}>
<div
data-slot="form-item"
className={cn("grid gap-2", className)}
{...props}
/>
</FormItemContext.Provider>
)
}
function FormLabel({
className,
...props
}: React.ComponentProps<typeof LabelPrimitive.Root>) {
const { error, formItemId } = useFormField()
return (
<Label
data-slot="form-label"
data-error={!!error}
className={cn("data-[error=true]:text-destructive", className)}
htmlFor={formItemId}
{...props}
/>
)
}
function FormControl({ ...props }: React.ComponentProps<typeof Slot>) {
const { error, formItemId, formDescriptionId, formMessageId } = useFormField()
return (
<Slot
data-slot="form-control"
id={formItemId}
aria-describedby={
!error
? `${formDescriptionId}`
: `${formDescriptionId} ${formMessageId}`
}
aria-invalid={!!error}
{...props}
/>
)
}
function FormDescription({ className, ...props }: React.ComponentProps<"p">) {
const { formDescriptionId } = useFormField()
return (
<p
data-slot="form-description"
id={formDescriptionId}
className={cn("text-muted-foreground text-sm", className)}
{...props}
/>
)
}
function FormMessage({ className, ...props }: React.ComponentProps<"p">) {
const { error, formMessageId } = useFormField()
const body = error ? String(error?.message ?? "") : props.children
if (!body) {
return null
}
return (
<p
data-slot="form-message"
id={formMessageId}
className={cn("text-destructive text-sm", className)}
{...props}
>
{body}
</p>
)
}
export {
useFormField,
Form,
FormItem,
FormLabel,
FormControl,
FormDescription,
FormMessage,
FormField,
}

View File

@ -0,0 +1,21 @@
import * as React from "react"
import { cn } from "@/lib/utils"
function Input({ className, type, ...props }: React.ComponentProps<"input">) {
return (
<input
type={type}
data-slot="input"
className={cn(
"file:text-foreground placeholder:text-muted-foreground selection:bg-primary selection:text-primary-foreground dark:bg-input/30 border-input flex h-9 w-full min-w-0 rounded-md border bg-transparent px-3 py-1 text-base shadow-xs transition-[color,box-shadow] outline-none file:inline-flex file:h-7 file:border-0 file:bg-transparent file:text-sm file:font-medium disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
"focus-visible:border-ring focus-visible:ring-ring/50 focus-visible:ring-[3px]",
"aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive",
className
)}
{...props}
/>
)
}
export { Input }

View File

@ -0,0 +1,24 @@
"use client"
import * as React from "react"
import * as LabelPrimitive from "@radix-ui/react-label"
import { cn } from "@/lib/utils"
function Label({
className,
...props
}: React.ComponentProps<typeof LabelPrimitive.Root>) {
return (
<LabelPrimitive.Root
data-slot="label"
className={cn(
"flex items-center gap-2 text-sm leading-none font-medium select-none group-data-[disabled=true]:pointer-events-none group-data-[disabled=true]:opacity-50 peer-disabled:cursor-not-allowed peer-disabled:opacity-50",
className
)}
{...props}
/>
)
}
export { Label }

View File

@ -0,0 +1,185 @@
"use client"
import * as React from "react"
import * as SelectPrimitive from "@radix-ui/react-select"
import { CheckIcon, ChevronDownIcon, ChevronUpIcon } from "lucide-react"
import { cn } from "@/lib/utils"
function Select({
...props
}: React.ComponentProps<typeof SelectPrimitive.Root>) {
return <SelectPrimitive.Root data-slot="select" {...props} />
}
function SelectGroup({
...props
}: React.ComponentProps<typeof SelectPrimitive.Group>) {
return <SelectPrimitive.Group data-slot="select-group" {...props} />
}
function SelectValue({
...props
}: React.ComponentProps<typeof SelectPrimitive.Value>) {
return <SelectPrimitive.Value data-slot="select-value" {...props} />
}
function SelectTrigger({
className,
size = "default",
children,
...props
}: React.ComponentProps<typeof SelectPrimitive.Trigger> & {
size?: "sm" | "default"
}) {
return (
<SelectPrimitive.Trigger
data-slot="select-trigger"
data-size={size}
className={cn(
"border-input data-[placeholder]:text-muted-foreground [&_svg:not([class*='text-'])]:text-muted-foreground focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:bg-input/30 dark:hover:bg-input/50 flex w-fit items-center justify-between gap-2 rounded-md border bg-transparent px-3 py-2 text-sm whitespace-nowrap shadow-xs transition-[color,box-shadow] outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50 data-[size=default]:h-9 data-[size=sm]:h-8 *:data-[slot=select-value]:line-clamp-1 *:data-[slot=select-value]:flex *:data-[slot=select-value]:items-center *:data-[slot=select-value]:gap-2 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4",
className
)}
{...props}
>
{children}
<SelectPrimitive.Icon asChild>
<ChevronDownIcon className="size-4 opacity-50" />
</SelectPrimitive.Icon>
</SelectPrimitive.Trigger>
)
}
function SelectContent({
className,
children,
position = "popper",
...props
}: React.ComponentProps<typeof SelectPrimitive.Content>) {
return (
<SelectPrimitive.Portal>
<SelectPrimitive.Content
data-slot="select-content"
className={cn(
"bg-popover text-popover-foreground data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0 data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2 relative z-50 max-h-(--radix-select-content-available-height) min-w-[8rem] origin-(--radix-select-content-transform-origin) overflow-x-hidden overflow-y-auto rounded-md border shadow-md",
position === "popper" &&
"data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1",
className
)}
position={position}
{...props}
>
<SelectScrollUpButton />
<SelectPrimitive.Viewport
className={cn(
"p-1",
position === "popper" &&
"h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)] scroll-my-1"
)}
>
{children}
</SelectPrimitive.Viewport>
<SelectScrollDownButton />
</SelectPrimitive.Content>
</SelectPrimitive.Portal>
)
}
function SelectLabel({
className,
...props
}: React.ComponentProps<typeof SelectPrimitive.Label>) {
return (
<SelectPrimitive.Label
data-slot="select-label"
className={cn("text-muted-foreground px-2 py-1.5 text-xs", className)}
{...props}
/>
)
}
function SelectItem({
className,
children,
...props
}: React.ComponentProps<typeof SelectPrimitive.Item>) {
return (
<SelectPrimitive.Item
data-slot="select-item"
className={cn(
"focus:bg-accent focus:text-accent-foreground [&_svg:not([class*='text-'])]:text-muted-foreground relative flex w-full cursor-default items-center gap-2 rounded-sm py-1.5 pr-8 pl-2 text-sm outline-hidden select-none data-[disabled]:pointer-events-none data-[disabled]:opacity-50 [&_svg]:pointer-events-none [&_svg]:shrink-0 [&_svg:not([class*='size-'])]:size-4 *:[span]:last:flex *:[span]:last:items-center *:[span]:last:gap-2",
className
)}
{...props}
>
<span className="absolute right-2 flex size-3.5 items-center justify-center">
<SelectPrimitive.ItemIndicator>
<CheckIcon className="size-4" />
</SelectPrimitive.ItemIndicator>
</span>
<SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
</SelectPrimitive.Item>
)
}
function SelectSeparator({
className,
...props
}: React.ComponentProps<typeof SelectPrimitive.Separator>) {
return (
<SelectPrimitive.Separator
data-slot="select-separator"
className={cn("bg-border pointer-events-none -mx-1 my-1 h-px", className)}
{...props}
/>
)
}
function SelectScrollUpButton({
className,
...props
}: React.ComponentProps<typeof SelectPrimitive.ScrollUpButton>) {
return (
<SelectPrimitive.ScrollUpButton
data-slot="select-scroll-up-button"
className={cn(
"flex cursor-default items-center justify-center py-1",
className
)}
{...props}
>
<ChevronUpIcon className="size-4" />
</SelectPrimitive.ScrollUpButton>
)
}
function SelectScrollDownButton({
className,
...props
}: React.ComponentProps<typeof SelectPrimitive.ScrollDownButton>) {
return (
<SelectPrimitive.ScrollDownButton
data-slot="select-scroll-down-button"
className={cn(
"flex cursor-default items-center justify-center py-1",
className
)}
{...props}
>
<ChevronDownIcon className="size-4" />
</SelectPrimitive.ScrollDownButton>
)
}
export {
Select,
SelectContent,
SelectGroup,
SelectItem,
SelectLabel,
SelectScrollDownButton,
SelectScrollUpButton,
SelectSeparator,
SelectTrigger,
SelectValue,
}

View File

@ -0,0 +1,13 @@
import { cn } from "@/lib/utils"
function Skeleton({ className, ...props }: React.ComponentProps<"div">) {
return (
<div
data-slot="skeleton"
className={cn("bg-accent animate-pulse rounded-md", className)}
{...props}
/>
)
}
export { Skeleton }

View File

@ -0,0 +1,63 @@
"use client"
import * as React from "react"
import * as SliderPrimitive from "@radix-ui/react-slider"
import { cn } from "@/lib/utils"
function Slider({
className,
defaultValue,
value,
min = 0,
max = 100,
...props
}: React.ComponentProps<typeof SliderPrimitive.Root>) {
const _values = React.useMemo(
() =>
Array.isArray(value)
? value
: Array.isArray(defaultValue)
? defaultValue
: [min, max],
[value, defaultValue, min, max]
)
return (
<SliderPrimitive.Root
data-slot="slider"
defaultValue={defaultValue}
value={value}
min={min}
max={max}
className={cn(
"relative flex w-full touch-none items-center select-none data-[disabled]:opacity-50 data-[orientation=vertical]:h-full data-[orientation=vertical]:min-h-44 data-[orientation=vertical]:w-auto data-[orientation=vertical]:flex-col",
className
)}
{...props}
>
<SliderPrimitive.Track
data-slot="slider-track"
className={cn(
"bg-muted relative grow overflow-hidden rounded-full data-[orientation=horizontal]:h-1.5 data-[orientation=horizontal]:w-full data-[orientation=vertical]:h-full data-[orientation=vertical]:w-1.5"
)}
>
<SliderPrimitive.Range
data-slot="slider-range"
className={cn(
"bg-primary absolute data-[orientation=horizontal]:h-full data-[orientation=vertical]:w-full"
)}
/>
</SliderPrimitive.Track>
{Array.from({ length: _values.length }, (_, index) => (
<SliderPrimitive.Thumb
data-slot="slider-thumb"
key={index}
className="border-primary bg-background ring-ring/50 block size-4 shrink-0 rounded-full border shadow-sm transition-[color,box-shadow] hover:ring-4 focus-visible:ring-4 focus-visible:outline-hidden disabled:pointer-events-none disabled:opacity-50"
/>
))}
</SliderPrimitive.Root>
)
}
export { Slider }

View File

@ -0,0 +1,25 @@
"use client"
import { useTheme } from "next-themes"
import { Toaster as Sonner, ToasterProps } from "sonner"
const Toaster = ({ ...props }: ToasterProps) => {
const { theme = "system" } = useTheme()
return (
<Sonner
theme={theme as ToasterProps["theme"]}
className="toaster group"
style={
{
"--normal-bg": "var(--popover)",
"--normal-text": "var(--popover-foreground)",
"--normal-border": "var(--border)",
} as React.CSSProperties
}
{...props}
/>
)
}
export { Toaster }

View File

@ -0,0 +1,18 @@
import * as React from "react"
import { cn } from "@/lib/utils"
function Textarea({ className, ...props }: React.ComponentProps<"textarea">) {
return (
<textarea
data-slot="textarea"
className={cn(
"border-input placeholder:text-muted-foreground focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:bg-input/30 flex field-sizing-content min-h-16 w-full rounded-md border bg-transparent px-3 py-2 text-base shadow-xs transition-[color,box-shadow] outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
className
)}
{...props}
/>
)
}
export { Textarea }

View File

@ -1,14 +1,39 @@
// File: frontend/lib/api.ts (建或修改)
// Description: 用于调用后端 API 的工具函数
// File: frontend/lib/api.ts (新)
// Description: 添加调用助手和会话管理 API 的函数
import { Assistant, AssistantCreateData, AssistantUpdateData } from "@/types/assistant";
import axios from "axios";
// 从环境变量读取后端 API 地址,如果没有则使用默认值
// 确保在 .env.local 文件中定义 NEXT_PUBLIC_API_URL=http://localhost:8000/api/v1
// --- Types ---
export interface Session {
id: string;
title: string;
assistant_id: string;
created_at: string; // ISO date string
updated_at?: string | null; // Add updated_at
}
// Message type from backend
export interface Message {
id: string;
session_id: string;
sender: 'user' | 'ai'; // Or extend with 'system' if needed
text: string;
order: number;
created_at: string; // ISO date string
}
// 聊天响应类型
export interface ChatApiResponse {
reply: string;
session_id?: string | null; // 后端返回的新 session id
session_title?: string | null; // 后端返回的新 session title
}
// --- API Client Setup ---
const API_BASE_URL =
process.env.NEXT_PUBLIC_API_URL || "http://localhost:8000/api/v1";
// 创建 axios 实例,可以设置一些全局配置
const apiClient = axios.create({
baseURL: API_BASE_URL,
headers: {
@ -16,36 +41,135 @@ const apiClient = axios.create({
},
});
// --- Helper for Error Handling ---
const handleApiError = (error: unknown, context: string): string => {
console.error(`API Error (${context}):`, error);
if (axios.isAxiosError(error) && error.response) {
// 尝试提取后端返回的详细错误信息
return (
error.response.data?.detail || `服务器错误 (${error.response.status})`
);
} else if (error instanceof Error) {
return error.message;
}
return "发生未知网络错误";
};
// --- Chat API ---
/**
*
* @param message
* @returns AI
* @throws API
* ()
* @param message
* @param sessionId ID ( 'temp-new-chat')
* @param assistantId ID
* @returns AI
*/
export const sendChatMessage = async (message: string): Promise<string> => {
export const sendChatMessage = async (
message: string,
sessionId: string,
assistantId: string
): Promise<ChatApiResponse> => {
try {
// 发送 POST 请求到后端的 /chat/ 端点
const response = await apiClient.post("/chat/", { message });
// 检查响应数据和 reply 字段是否存在
if (response.data && response.data.reply) {
return response.data.reply; // 返回 AI 的回复
} else {
// 如果响应格式不符合预期,抛出错误
throw new Error("Invalid response format from server");
}
const response = await apiClient.post<ChatApiResponse>("/chat/", {
message,
session_id: sessionId,
assistant_id: assistantId,
});
return response.data; // 返回整个响应体
} catch (error) {
console.error("Error calling chat API:", error); // 在控制台打印详细错误
// 检查是否是 Axios 错误并且有响应体
if (axios.isAxiosError(error) && error.response) {
// 尝试从响应体中获取错误详情,否则提供通用消息
throw new Error(
error.response.data?.detail || "Failed to communicate with server"
);
}
// 如果不是 Axios 错误或没有响应体,抛出通用错误
throw new Error("Failed to send message. Please try again.");
throw new Error(handleApiError(error, "sendChatMessage"));
}
};
// --- 你可以在这里添加调用其他后端 API 的函数 ---
// export const getWorkflows = async () => { ... };
// --- Assistant API ---
/** 获取所有助手列表 */
export const getAssistants = async (): Promise<Assistant[]> => {
try {
const response = await apiClient.get<Assistant[]>("/assistants/");
return response.data;
} catch (error) {
throw new Error(handleApiError(error, "getAssistants"));
}
};
/** 创建新助手 */
export const createAssistant = async (
data: AssistantCreateData
): Promise<Assistant> => {
try {
const response = await apiClient.post<Assistant>("/assistants/", data);
return response.data;
} catch (error) {
throw new Error(handleApiError(error, "createAssistant"));
}
};
/** 更新助手 */
export const updateAssistant = async (
id: string,
data: AssistantUpdateData
): Promise<Assistant> => {
try {
const response = await apiClient.put<Assistant>(`/assistants/${id}`, data);
return response.data;
} catch (error) {
throw new Error(handleApiError(error, "updateAssistant"));
}
};
/** 删除助手 */
export const deleteAssistant = async (id: string): Promise<void> => {
try {
await apiClient.delete(`/assistants/${id}`);
} catch (error) {
throw new Error(handleApiError(error, "deleteAssistant"));
}
};
// --- Session API ---
/** 获取指定助手的所有会话 */
export const getSessionsByAssistant = async (
assistantId: string
): Promise<Session[]> => {
try {
const response = await apiClient.get<Session[]>(
`/sessions/assistant/${assistantId}`
);
return response.data;
} catch (error) {
// 如果助手没有会话,后端可能返回 404 或空列表,这里统一处理为返回空列表
if (axios.isAxiosError(error) && error.response?.status === 404) {
return [];
}
throw new Error(handleApiError(error, "getSessionsByAssistant"));
}
};
/** 删除会话 */
export const deleteSession = async (sessionId: string): Promise<void> => {
try {
await apiClient.delete(`/sessions/${sessionId}`);
} catch (error) {
throw new Error(handleApiError(error, "deleteSession"));
}
};
// 注意:创建会话的 API (POST /sessions/) 在后端被整合到了 POST /chat/ 逻辑中,
//当前端发送 sessionId 为 'temp-new-chat' 的消息时,后端会自动创建。
//如果需要单独创建会话(例如,不发送消息就创建),则需要单独实现前端调用 POST /sessions/。
// --- Message API (New) ---
/** 获取指定会话的消息列表 */
export const getMessagesBySession = async (sessionId: string, limit: number = 100, skip: number = 0): Promise<Message[]> => {
try {
const response = await apiClient.get<Message[]>(`/messages/session/${sessionId}`, {
params: { limit, skip }
});
return response.data;
} catch (error) {
// Handle 404 specifically if needed (session exists but no messages)
if (axios.isAxiosError(error) && error.response?.status === 404) {
return []; // Return empty list if session not found or no messages
}
throw new Error(handleApiError(error, 'getMessagesBySession'));
}
};

6
frontend/lib/utils.ts Normal file
View File

@ -0,0 +1,6 @@
import { clsx, type ClassValue } from "clsx"
import { twMerge } from "tailwind-merge"
export function cn(...inputs: ClassValue[]) {
return twMerge(clsx(inputs))
}

View File

@ -9,11 +9,24 @@
"lint": "next lint"
},
"dependencies": {
"@hookform/resolvers": "^5.0.1",
"@radix-ui/react-dialog": "^1.1.11",
"@radix-ui/react-label": "^2.1.4",
"@radix-ui/react-select": "^2.2.2",
"@radix-ui/react-slider": "^1.3.2",
"@radix-ui/react-slot": "^1.2.0",
"axios": "^1.9.0",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"lucide-react": "^0.503.0",
"next": "15.3.1",
"next-themes": "^0.4.6",
"react": "^19.0.0",
"react-dom": "^19.0.0"
"react-dom": "^19.0.0",
"react-hook-form": "^7.56.1",
"sonner": "^2.0.3",
"tailwind-merge": "^3.2.0",
"zod": "^3.24.3"
},
"devDependencies": {
"@eslint/eslintrc": "^3",
@ -26,6 +39,7 @@
"eslint-config-next": "15.3.1",
"postcss": "^8.5.3",
"tailwindcss": "^4",
"tw-animate-css": "^1.2.8",
"typescript": "^5"
},
"packageManager": "pnpm@10.10.0+sha512.d615db246fe70f25dcfea6d8d73dee782ce23e2245e3c4f6f888249fb568149318637dca73c2c5c8ef2a4ca0d5657fb9567188bfab47f566d1ee6ce987815c39"

875
frontend/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,220 @@
@import "tailwindcss";
@import "tw-animate-css";
@custom-variant dark (&:is(.dark *));
@tailwind utilities;
@theme inline {
--radius-sm: calc(var(--radius) - 4px);
--radius-md: calc(var(--radius) - 2px);
--radius-lg: var(--radius);
--radius-xl: calc(var(--radius) + 4px);
--color-background: var(--background);
--color-foreground: var(--foreground);
--color-card: var(--card);
--color-card-foreground: var(--card-foreground);
--color-popover: var(--popover);
--color-popover-foreground: var(--popover-foreground);
--color-primary: var(--primary);
--color-primary-foreground: var(--primary-foreground);
--color-secondary: var(--secondary);
--color-secondary-foreground: var(--secondary-foreground);
--color-muted: var(--muted);
--color-muted-foreground: var(--muted-foreground);
--color-accent: var(--accent);
--color-accent-foreground: var(--accent-foreground);
--color-destructive: var(--destructive);
--color-border: var(--border);
--color-input: var(--input);
--color-ring: var(--ring);
--color-chart-1: var(--chart-1);
--color-chart-2: var(--chart-2);
--color-chart-3: var(--chart-3);
--color-chart-4: var(--chart-4);
--color-chart-5: var(--chart-5);
--color-sidebar: var(--sidebar);
--color-sidebar-foreground: var(--sidebar-foreground);
--color-sidebar-primary: var(--sidebar-primary);
--color-sidebar-primary-foreground: var(--sidebar-primary-foreground);
--color-sidebar-accent: var(--sidebar-accent);
--color-sidebar-accent-foreground: var(--sidebar-accent-foreground);
--color-sidebar-border: var(--sidebar-border);
--color-sidebar-ring: var(--sidebar-ring);
}
:root {
--radius: 0.625rem;
--background: oklch(1 0 0);
--foreground: oklch(0.141 0.005 285.823);
--card: oklch(1 0 0);
--card-foreground: oklch(0.141 0.005 285.823);
--popover: oklch(1 0 0);
--popover-foreground: oklch(0.141 0.005 285.823);
--primary: oklch(0.21 0.006 285.885);
--primary-foreground: oklch(0.985 0 0);
--secondary: oklch(0.967 0.001 286.375);
--secondary-foreground: oklch(0.21 0.006 285.885);
--muted: oklch(0.967 0.001 286.375);
--muted-foreground: oklch(0.552 0.016 285.938);
--accent: oklch(0.967 0.001 286.375);
--accent-foreground: oklch(0.21 0.006 285.885);
--destructive: oklch(0.577 0.245 27.325);
--border: oklch(0.92 0.004 286.32);
--input: oklch(0.92 0.004 286.32);
--ring: oklch(0.705 0.015 286.067);
--chart-1: oklch(0.646 0.222 41.116);
--chart-2: oklch(0.6 0.118 184.704);
--chart-3: oklch(0.398 0.07 227.392);
--chart-4: oklch(0.828 0.189 84.429);
--chart-5: oklch(0.769 0.188 70.08);
--sidebar: oklch(0.985 0 0);
--sidebar-foreground: oklch(0.141 0.005 285.823);
--sidebar-primary: oklch(0.21 0.006 285.885);
--sidebar-primary-foreground: oklch(0.985 0 0);
--sidebar-accent: oklch(0.967 0.001 286.375);
--sidebar-accent-foreground: oklch(0.21 0.006 285.885);
--sidebar-border: oklch(0.92 0.004 286.32);
--sidebar-ring: oklch(0.705 0.015 286.067);
}
.dark {
--background: oklch(0.141 0.005 285.823);
--foreground: oklch(0.985 0 0);
--card: oklch(0.21 0.006 285.885);
--card-foreground: oklch(0.985 0 0);
--popover: oklch(0.21 0.006 285.885);
--popover-foreground: oklch(0.985 0 0);
--primary: oklch(0.92 0.004 286.32);
--primary-foreground: oklch(0.21 0.006 285.885);
--secondary: oklch(0.274 0.006 286.033);
--secondary-foreground: oklch(0.985 0 0);
--muted: oklch(0.274 0.006 286.033);
--muted-foreground: oklch(0.705 0.015 286.067);
--accent: oklch(0.274 0.006 286.033);
--accent-foreground: oklch(0.985 0 0);
--destructive: oklch(0.704 0.191 22.216);
--border: oklch(1 0 0 / 10%);
--input: oklch(1 0 0 / 15%);
--ring: oklch(0.552 0.016 285.938);
--chart-1: oklch(0.488 0.243 264.376);
--chart-2: oklch(0.696 0.17 162.48);
--chart-3: oklch(0.769 0.188 70.08);
--chart-4: oklch(0.627 0.265 303.9);
--chart-5: oklch(0.645 0.246 16.439);
--sidebar: oklch(0.21 0.006 285.885);
--sidebar-foreground: oklch(0.985 0 0);
--sidebar-primary: oklch(0.488 0.243 264.376);
--sidebar-primary-foreground: oklch(0.985 0 0);
--sidebar-accent: oklch(0.274 0.006 286.033);
--sidebar-accent-foreground: oklch(0.985 0 0);
--sidebar-border: oklch(1 0 0 / 10%);
--sidebar-ring: oklch(0.552 0.016 285.938);
}
@layer base {
* {
@apply border-border outline-ring/50;
}
body {
@apply bg-background text-foreground;
}
}

View File

@ -0,0 +1,23 @@
// --- Types (从后端模型同步或手动定义) ---
// 这些类型应该与后端 pydantic_models.py 中的 Read 模型匹配
export interface Assistant {
id: string;
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
// 创建助手时发送的数据类型
export interface AssistantCreateData {
name: string;
description?: string | null;
avatar?: string | null;
system_prompt: string;
model: string;
temperature: number;
}
// 更新助手时发送的数据类型 (所有字段可选)
export type AssistantUpdateData = Partial<AssistantCreateData>;