43 lines
1.5 KiB
Python
43 lines
1.5 KiB
Python
# File: backend/app/db/database.py (New - Database setup)
|
|
# Description: SQLAlchemy database engine and session setup
|
|
|
|
from typing import AsyncIterator

from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base

from app.core.config import DATABASE_URL
|
|
|
|
# Create the async database engine.
# connect_args={"check_same_thread": False} is needed only for SQLite: it lets a
# SQLite connection be used from a thread other than the one that created it.
# Other drivers (e.g. asyncpg) do not recognize this argument and may reject it
# when connecting, so it is passed only when the URL uses the sqlite dialect
# (this also covers driver-qualified URLs such as "sqlite+aiosqlite://...").
_connect_args = (
    {"check_same_thread": False} if str(DATABASE_URL).startswith("sqlite") else {}
)
engine = create_async_engine(DATABASE_URL, echo=True, connect_args=_connect_args)
|
|
|
|
# Async session factory. Every session it produces:
#   - is an AsyncSession bound to `engine`;
#   - does not flush pending changes automatically before queries (autoflush=False);
#   - keeps loaded attribute values after commit instead of expiring them
#     (expire_on_commit=False), so objects stay readable once committed;
#   - is created with autocommit explicitly set to False.
AsyncSessionFactory = sessionmaker(
    class_=AsyncSession,
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)

# Declarative base class for the data models; the tables of the models
# declared on it are collected in Base.metadata.
Base = declarative_base()
|
|
|
|
# --- Dependency to get DB session ---
async def get_db_session() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields an async database session.

    A session is opened from ``AsyncSessionFactory`` and yielded to the caller.
    If control returns to the generator normally, the transaction is committed.
    If an exception is raised back into the generator at the ``yield``, or the
    commit itself fails, the transaction is rolled back and the exception is
    re-raised. The ``async with`` block closes the session in every case, so no
    explicit ``close()`` is needed.

    Yields:
        An ``AsyncSession`` created by ``AsyncSessionFactory``.
    """
    async with AsyncSessionFactory() as session:
        try:
            yield session
            # Reached only when the request handling finished without error.
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
|
|
|
# --- Function to create tables (call this on startup) ---
async def create_db_and_tables():
    """Create the tables registered on ``Base.metadata``.

    Opens a transactional connection with ``engine.begin()`` and runs the
    synchronous ``MetaData.create_all`` on it via ``run_sync``. By default,
    ``create_all`` skips tables that already exist.

    NOTE: to reset the schema during development, run
    ``Base.metadata.drop_all`` the same way before ``create_all``. This is
    destructive, so never do it against real data.
    """
    metadata = Base.metadata
    async with engine.begin() as connection:
        await connection.run_sync(metadata.create_all)