# Mirror of https://github.com/handsomezhuzhu/QQuiz.git
# Synced 2026-02-20 12:00:14 +00:00 (163 lines, 5.0 KiB, Python)
"""
|
|
Database configuration and session management
|
|
"""
|
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
|
from sqlalchemy.pool import NullPool
|
|
from contextlib import asynccontextmanager
|
|
from typing import AsyncGenerator
|
|
import os
|
|
from dotenv import load_dotenv
|
|
|
|
# Populate os.environ from a local .env file (if one exists) before reading config.
load_dotenv()

# The connection string is mandatory: fail fast at import time when it is missing
# or empty rather than deferring the error to the first query.
DATABASE_URL = os.environ.get("DATABASE_URL")
if DATABASE_URL is None or DATABASE_URL == "":
    raise ValueError("DATABASE_URL environment variable is not set")
|
|
|
|
# Create async engine.
# SQLite URLs (e.g. "sqlite+aiosqlite:///...") use NullPool; for any other backend
# poolclass=None leaves the pool choice to SQLAlchemy's default for that dialect.
# The check looks at the URL scheme prefix only, so a password, host or database
# name that merely contains "sqlite" no longer disables connection pooling.
engine = create_async_engine(
    DATABASE_URL,
    echo=False,  # Set to True for SQL query logging during development
    future=True,
    poolclass=NullPool if DATABASE_URL.startswith("sqlite") else None,
)
|
|
|
|
# Factory for the AsyncSession objects handed out by get_db() / get_db_context().
#   expire_on_commit=False -> loaded attributes stay readable after a commit
#   autoflush=False        -> pending changes are not flushed automatically before queries
#   autocommit=False       -> legacy flag; False is the only supported value in SQLAlchemy 2.x
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
|
|
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency for getting async database session.

    Delegates to ``get_db_context`` so both entry points share one
    transaction policy. The session is committed after the handler finishes
    normally. If an exception is propagated back into this dependency, the
    session is rolled back and the exception is re-raised.

    Usage in FastAPI:
        @app.get("/items")
        async def get_items(db: AsyncSession = Depends(get_db)):
            ...
    """
    async with get_db_context() as session:
        yield session
|
|
|
|
|
|
async def init_db():
    """
    Initialize database tables.

    Should be called during application startup. Runs
    ``Base.metadata.create_all`` inside an ``engine.begin()`` transaction.
    The success message is printed only after that block has exited, i.e.
    after the transaction has been committed.
    """
    from models import Base

    async with engine.begin() as conn:
        # create_all is a synchronous API; run_sync executes it on the async connection.
        await conn.run_sync(Base.metadata.create_all)

    print("✅ Database tables created successfully")
|
|
|
|
|
|
async def init_default_config(db: AsyncSession):
    """
    Initialize default system configurations and the default admin account.

    Works on the given session and commits it. Steps, in order:

    1. Validate ``ADMIN_USERNAME`` / ``ADMIN_PASSWORD`` from the environment,
       before anything is written.
    2. Insert a ``SystemConfig`` row for each built-in default whose key is
       not present yet. Existing rows are left unchanged.
    3. Create the user named ``ADMIN_USERNAME`` as an admin. If it already
       exists, grant it admin rights if missing and re-hash its password
       when ``ADMIN_PASSWORD`` no longer matches the stored hash.
    4. Store that user's id under the ``default_admin_id`` config key.

    Args:
        db: An open async session.

    Raises:
        ValueError: If ``ADMIN_USERNAME`` is shorter than 3 characters, or
            ``ADMIN_PASSWORD`` is unset or shorter than 12 characters.
    """
    from models import SystemConfig, User
    from sqlalchemy import select
    from passlib.context import CryptContext

    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

    # Default configurations: environment overrides with fallbacks (all strings).
    default_configs = {
        "allow_registration": os.getenv("ALLOW_REGISTRATION", "true"),
        "max_upload_size_mb": os.getenv("MAX_UPLOAD_SIZE_MB", "10"),
        "max_daily_uploads": os.getenv("MAX_DAILY_UPLOADS", "20"),
        "ai_provider": os.getenv("AI_PROVIDER", "openai"),
    }

    # Validate admin credentials before touching the database.
    admin_username = os.getenv("ADMIN_USERNAME", "admin")
    if not admin_username or len(admin_username) < 3:
        raise ValueError("ADMIN_USERNAME must be at least 3 characters long")

    admin_password = os.getenv("ADMIN_PASSWORD")
    if not admin_password or len(admin_password) < 12:
        raise ValueError("ADMIN_PASSWORD must be set and at least 12 characters long")

    # --- Seed missing config rows (one query instead of one per key) ---
    result = await db.execute(
        select(SystemConfig.key).where(SystemConfig.key.in_(list(default_configs)))
    )
    existing_keys = set(result.scalars().all())
    for key, value in default_configs.items():
        if key not in existing_keys:
            db.add(SystemConfig(key=key, value=value))
            print(f"✅ Created default config: {key} = {value}")

    # --- Create or update the default admin user ---
    result = await db.execute(select(User).where(User.username == admin_username))
    admin = result.scalar_one_or_none()

    if admin is None:
        admin = User(
            username=admin_username,
            hashed_password=pwd_context.hash(admin_password),
            is_admin=True,
        )
        db.add(admin)
        # Commit (together with any config rows added above), then reload so
        # the generated primary key is populated.
        await db.commit()
        await db.refresh(admin)
        print(f"✅ Created default admin user (username: {admin_username})")
    else:
        # This account is recorded below as the default admin, so it must
        # actually hold admin rights.
        if not admin.is_admin:
            admin.is_admin = True
            print(f"🔄 Granted admin rights to default admin (username: {admin_username})")

        # Re-hash only when the configured password no longer matches.
        try:
            password_matches = pwd_context.verify(admin_password, admin.hashed_password)
        except ValueError:
            # passlib raises ValueError when the stored hash is in a format it
            # cannot identify; treat that as a mismatch and overwrite it.
            password_matches = False
        if not password_matches:
            admin.hashed_password = pwd_context.hash(admin_password)
            print(f"🔄 Updated default admin password (username: {admin_username})")

        await db.commit()

    default_admin_id = admin.id

    # --- Record which user is the default admin ---
    if default_admin_id is not None:
        admin_id_value = str(default_admin_id)
        result = await db.execute(
            select(SystemConfig).where(SystemConfig.key == "default_admin_id")
        )
        default_admin_config = result.scalar_one_or_none()

        if default_admin_config is None:
            db.add(SystemConfig(key="default_admin_id", value=admin_id_value))
        elif default_admin_config.value != admin_id_value:
            default_admin_config.value = admin_id_value

    await db.commit()
|
|
|
|
|
|
@asynccontextmanager
async def get_db_context():
    """
    Context manager for getting database session outside of FastAPI dependency injection.

    The session is committed when the ``async with`` body exits normally. If
    the body raises, the session is rolled back and the exception re-raised.
    Closing is handled by the ``async with AsyncSessionLocal()`` block, so no
    explicit ``close()`` call is needed.

    Usage:
        async with get_db_context() as db:
            result = await db.execute(...)
    """
    async with AsyncSessionLocal() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|