cherry-ai/backend/app/services/assistant_service.py
2025-04-30 04:39:36 +08:00

77 lines
3.5 KiB
Python

# File: backend/app/services/assistant_service.py (Update with DB)
# Description: 管理助手数据的服务 (使用 SQLAlchemy)
from typing import List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update as sqlalchemy_update, delete as sqlalchemy_delete
from app.db.models import AssistantModel
from app.models.pydantic_models import AssistantCreate, AssistantUpdate, AssistantRead
class AssistantService:
    """CRUD service for assistant records, backed by the database (SQLAlchemy async).

    Every method operates on the ``AsyncSession`` passed in by the caller.
    Methods ``flush`` their changes but never ``commit``. Committing or rolling
    back is presumably the responsibility of whoever owns the session; verify
    against the callers.
    """

    # ID of the built-in assistant that ``delete_assistant`` refuses to remove.
    DEFAULT_ASSISTANT_ID = "asst-default"

    async def get_assistants(self, db: AsyncSession) -> List[AssistantRead]:
        """Return all assistants ordered by name.

        Args:
            db: Async database session.

        Returns:
            A list of ``AssistantRead`` models (empty when there are none).
        """
        result = await db.execute(select(AssistantModel).order_by(AssistantModel.name))
        assistants = result.scalars().all()
        # Pydantic v2 ``model_validate`` builds the read model from the ORM
        # object (assumes AssistantRead enables ``from_attributes`` — confirm).
        return [AssistantRead.model_validate(a) for a in assistants]

    async def get_assistant(self, db: AsyncSession, assistant_id: str) -> Optional[AssistantRead]:
        """Return the assistant with the given ID, or ``None`` if it does not exist.

        Args:
            db: Async database session.
            assistant_id: ID of the assistant to look up.
        """
        result = await db.execute(select(AssistantModel).filter(AssistantModel.id == assistant_id))
        assistant = result.scalars().first()
        return AssistantRead.model_validate(assistant) if assistant else None

    async def create_assistant(self, db: AsyncSession, assistant_data: AssistantCreate) -> AssistantRead:
        """Insert a new assistant and return it as an ``AssistantRead``.

        Args:
            db: Async database session.
            assistant_data: Validated creation payload; every field is passed
                straight into the ``AssistantModel`` constructor.

        Returns:
            The newly created assistant, including database-generated values.
        """
        # Build the ORM instance from the Pydantic payload. The ID is not
        # supplied here, so it presumably comes from a default on the model.
        db_assistant = AssistantModel(**assistant_data.model_dump())
        db.add(db_assistant)
        # Flush (not commit) so the INSERT runs and generated ID/defaults exist,
        # then refresh to load every column into the instance.
        await db.flush()
        await db.refresh(db_assistant)
        print(f"助手已创建 (DB): {db_assistant.id} - {db_assistant.name}")
        return AssistantRead.model_validate(db_assistant)

    async def update_assistant(self, db: AsyncSession, assistant_id: str, assistant_data: AssistantUpdate) -> Optional[AssistantRead]:
        """Apply a partial update to an assistant.

        Only fields explicitly set on ``assistant_data`` are written
        (``exclude_unset=True``).

        Args:
            db: Async database session.
            assistant_id: ID of the assistant to update.
            assistant_data: Partial update payload.

        Returns:
            The updated assistant, or ``None`` if no assistant has that ID.
            When the payload sets no fields, the current record is returned
            unchanged (or ``None`` if it does not exist).
        """
        update_values = assistant_data.model_dump(exclude_unset=True)
        if not update_values:
            # Nothing to write: just return the current state.
            return await self.get_assistant(db, assistant_id)

        # UPDATE ... RETURNING hands back the updated ORM row in one round trip.
        # NOTE(review): RETURNING requires backend support (e.g. PostgreSQL,
        # SQLite >= 3.35); confirm for the configured database.
        stmt = (
            sqlalchemy_update(AssistantModel)
            .where(AssistantModel.id == assistant_id)
            .values(**update_values)
            .returning(AssistantModel)
        )
        result = await db.execute(stmt)
        updated_assistant = result.scalars().first()
        if updated_assistant is None:
            return None  # No assistant with this ID.

        await db.flush()
        await db.refresh(updated_assistant)
        print(f"助手已更新 (DB): {assistant_id}")
        return AssistantRead.model_validate(updated_assistant)

    async def delete_assistant(self, db: AsyncSession, assistant_id: str) -> bool:
        """Delete an assistant by ID.

        The built-in default assistant (``DEFAULT_ASSISTANT_ID``) is never
        deleted.

        The row is deleted through the ORM (``AsyncSession.delete``) rather
        than a bulk ``DELETE`` statement. Bulk statements bypass ORM
        relationship cascades, so any ``cascade="all, delete-orphan"`` on the
        assistant's related sessions/messages would not run. That cascade is
        configured in ``app.db.models``, which is not visible here; confirm it.

        Args:
            db: Async database session.
            assistant_id: ID of the assistant to delete.

        Returns:
            True if the assistant existed and was deleted. False if it is the
            protected default assistant or no assistant has that ID.
        """
        if assistant_id == self.DEFAULT_ASSISTANT_ID:
            print("尝试删除默认助手 - 操作被阻止")
            return False

        result = await db.execute(select(AssistantModel).filter(AssistantModel.id == assistant_id))
        db_assistant = result.scalars().first()
        if db_assistant is None:
            return False

        # AsyncSession.delete is awaitable so it can load not-yet-loaded
        # related objects that must be cascaded along with this row.
        await db.delete(db_assistant)
        await db.flush()
        print(f"助手已删除 (DB): {assistant_id}")
        return True