mirror of
https://github.com/handsomezhuzhu/QQuiz.git
synced 2026-02-20 12:00:14 +00:00
🎉 Initial commit: QQuiz - 智能刷题与题库管理平台
## 功能特性 ✅ **核心功能** - 多文件上传与智能去重(基于 content_hash) - 异步文档解析(支持 TXT/PDF/DOCX/XLSX) - AI 智能题目提取与评分(OpenAI/Anthropic/Qwen) - 断点续做与进度管理 - 自动错题本收集 ✅ **技术栈** - Backend: FastAPI + SQLAlchemy 2.0 + PostgreSQL - Frontend: React 18 + Vite + Tailwind CSS - Deployment: Docker Compose ✅ **项目结构** - 53 个文件 - 完整的前后端分离架构 - Docker/源码双模部署支持 🚀 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
21
backend/.dockerignore
Normal file
21
backend/.dockerignore
Normal file
@@ -0,0 +1,21 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
.Python
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
.venv
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
dist/
|
||||
build/
|
||||
*.egg-info/
|
||||
.DS_Store
|
||||
.env
|
||||
uploads/
|
||||
*.sqlite3
|
||||
25
backend/Dockerfile
Normal file
25
backend/Dockerfile
Normal file
@@ -0,0 +1,25 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
postgresql-client \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements and install Python dependencies
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Create uploads directory
|
||||
RUN mkdir -p uploads
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Run database migrations and start server
|
||||
CMD alembic upgrade head && uvicorn main:app --host 0.0.0.0 --port 8000
|
||||
94
backend/alembic.ini
Normal file
94
backend/alembic.ini
Normal file
@@ -0,0 +1,94 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
timezone = UTC
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
version_locations = %(here)s/alembic/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
output_encoding = utf-8
|
||||
|
||||
# sqlalchemy.url will be read from environment variable DATABASE_URL
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
81
backend/alembic/env.py
Normal file
81
backend/alembic/env.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Alembic environment configuration for async SQLAlchemy
|
||||
"""
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import async_engine_from_config
|
||||
from alembic import context
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Import models
|
||||
from models import Base
|
||||
|
||||
# this is the Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Set database URL from environment
|
||||
config.set_main_option("sqlalchemy.url", os.getenv("DATABASE_URL", ""))
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here for 'autogenerate' support
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
    """Emit migration SQL without a live DB connection ('offline' mode)."""
    # Configure straight from the configured URL; literal_binds inlines
    # parameters so the output is a runnable SQL script.
    context.configure(
        url=config.get_main_option("sqlalchemy.url"),
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection: Connection) -> None:
    """Configure the Alembic context on *connection*, then run migrations."""
    context.configure(connection=connection, target_metadata=target_metadata)
    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
    """Build an async engine and apply migrations over a single connection."""
    # NullPool: a one-shot migration run gains nothing from pooling.
    engine = async_engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    async with engine.connect() as conn:
        await conn.run_sync(do_run_migrations)
    await engine.dispose()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
    """Entry point for 'online' mode: drive the async runner to completion."""
    asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
24
backend/alembic/script.py.mako
Normal file
24
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,24 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
1
backend/alembic/versions/.gitkeep
Normal file
1
backend/alembic/versions/.gitkeep
Normal file
@@ -0,0 +1 @@
|
||||
# This file ensures the versions directory is tracked by git
|
||||
132
backend/database.py
Normal file
132
backend/database.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Database configuration and session management
|
||||
"""
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Get database URL from environment
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
|
||||
if not DATABASE_URL:
|
||||
raise ValueError("DATABASE_URL environment variable is not set")
|
||||
|
||||
# Create async engine
|
||||
engine = create_async_engine(
|
||||
DATABASE_URL,
|
||||
echo=False, # Set to True for SQL query logging during development
|
||||
future=True,
|
||||
poolclass=NullPool if "sqlite" in DATABASE_URL else None,
|
||||
)
|
||||
|
||||
# Create async session factory
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    FastAPI dependency yielding an async database session.

    Commits when the request handler finishes normally, rolls back (and
    re-raises) on any exception, and always closes the session.

    Usage in FastAPI:
        @app.get("/items")
        async def get_items(db: AsyncSession = Depends(get_db)):
            ...
    """
    session = AsyncSessionLocal()
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()
|
||||
|
||||
|
||||
async def init_db():
    """
    Create all ORM tables; intended to run once during application startup.
    """
    # Local import avoids a circular dependency at module load time.
    from models import Base

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    print("✅ Database tables created successfully")
|
||||
|
||||
|
||||
async def init_default_config(db: AsyncSession):
    """
    Seed missing system-config rows and the default admin account.

    Idempotent: rows that already exist are left untouched, so it is safe
    to call on every startup.
    """
    from models import SystemConfig, User
    from sqlalchemy import select
    from passlib.context import CryptContext

    pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

    # Seed values come from the environment, with sane fallbacks.
    defaults = {
        "allow_registration": os.getenv("ALLOW_REGISTRATION", "true"),
        "max_upload_size_mb": os.getenv("MAX_UPLOAD_SIZE_MB", "10"),
        "max_daily_uploads": os.getenv("MAX_DAILY_UPLOADS", "20"),
        "ai_provider": os.getenv("AI_PROVIDER", "openai"),
    }

    for key, value in defaults.items():
        found = (
            await db.execute(select(SystemConfig).where(SystemConfig.key == key))
        ).scalar_one_or_none()
        if found is None:
            db.add(SystemConfig(key=key, value=str(value)))
            print(f"✅ Created default config: {key} = {value}")

    # Bootstrap the admin account on first run only.
    admin = (
        await db.execute(select(User).where(User.username == "admin"))
    ).scalar_one_or_none()
    if admin is None:
        db.add(
            User(
                username="admin",
                hashed_password=pwd_context.hash("admin123"),  # Change this password!
                is_admin=True,
            )
        )
        print("✅ Created default admin user (username: admin, password: admin123)")
        print("⚠️ IMPORTANT: Please change the admin password immediately!")

    await db.commit()
|
||||
|
||||
|
||||
@asynccontextmanager
async def get_db_context():
    """
    Async context manager yielding a session outside FastAPI's DI system.

    Same commit/rollback/close semantics as get_db().

    Usage:
        async with get_db_context() as db:
            result = await db.execute(...)
    """
    session = AsyncSessionLocal()
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()
|
||||
84
backend/main.py
Normal file
84
backend/main.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
QQuiz FastAPI Application
|
||||
"""
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from database import init_db, init_default_config, get_db_context
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hooks for the FastAPI application."""
    # --- startup ---
    print("🚀 Starting QQuiz Application...")

    # Create tables, then seed default config / admin account.
    await init_db()
    async with get_db_context() as db:
        await init_default_config(db)

    # Make sure the upload target directory exists.
    uploads_path = os.getenv("UPLOAD_DIR", "./uploads")
    os.makedirs(uploads_path, exist_ok=True)
    print(f"📁 Upload directory: {uploads_path}")

    print("✅ Application started successfully!")

    yield

    # --- shutdown ---
    print("👋 Shutting down QQuiz Application...")
|
||||
|
||||
|
||||
# Application instance with lifespan-managed startup/shutdown.
app = FastAPI(
    title="QQuiz API",
    description="智能刷题与题库管理平台",
    version="1.0.0",
    lifespan=lifespan
)

# CORS: allowed origins are a comma-separated env list; default is the
# local frontend dev server.
cors_origins = os.getenv("CORS_ORIGINS", "http://localhost:3000").split(",")
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Landing endpoint: service name, version, and docs location."""
    return {
        "message": "Welcome to QQuiz API",
        "version": "1.0.0",
        "docs": "/docs"
    }
|
||||
|
||||
|
||||
@app.get("/health")
async def health_check():
    """Liveness probe for deployment health checks."""
    return {"status": "healthy"}
|
||||
|
||||
|
||||
# Import and include routers
|
||||
from routers import auth, exam, question, mistake, admin
|
||||
|
||||
app.include_router(auth.router, prefix="/api/auth", tags=["Authentication"])
|
||||
app.include_router(exam.router, prefix="/api/exams", tags=["Exams"])
|
||||
app.include_router(question.router, prefix="/api/questions", tags=["Questions"])
|
||||
app.include_router(mistake.router, prefix="/api/mistakes", tags=["Mistakes"])
|
||||
app.include_router(admin.router, prefix="/api/admin", tags=["Admin"])
|
||||
134
backend/models.py
Normal file
134
backend/models.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
SQLAlchemy Models for QQuiz Platform
|
||||
"""
|
||||
from datetime import datetime
|
||||
from enum import Enum as PyEnum
|
||||
from sqlalchemy import (
|
||||
Column, Integer, String, Boolean, DateTime,
|
||||
ForeignKey, Text, JSON, Index, Enum
|
||||
)
|
||||
from sqlalchemy.orm import relationship, declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class ExamStatus(str, PyEnum):
    """Lifecycle states of an exam's document-processing pipeline."""
    PENDING = "pending"        # uploaded, waiting for the parser
    PROCESSING = "processing"  # background parsing in progress
    READY = "ready"            # questions extracted and available
    FAILED = "failed"          # parsing or extraction failed
|
||||
|
||||
|
||||
class QuestionType(str, PyEnum):
    """Supported question formats."""
    SINGLE = "single"      # single choice
    MULTIPLE = "multiple"  # multiple choice
    JUDGE = "judge"        # true/false
    SHORT = "short"        # short answer
|
||||
|
||||
|
||||
class User(Base):
    """A registered account; owns exams and a mistake book."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    username = Column(String(50), unique=True, nullable=False, index=True)
    hashed_password = Column(String(255), nullable=False)
    is_admin = Column(Boolean, default=False, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)

    # One-to-many links; deleting a user cascades to their exams/mistakes.
    exams = relationship("Exam", back_populates="user", cascade="all, delete-orphan")
    mistakes = relationship("UserMistake", back_populates="user", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<User(id={self.id}, username='{self.username}', is_admin={self.is_admin})>"
|
||||
|
||||
|
||||
class SystemConfig(Base):
    """Key-value store for runtime-tunable system settings."""

    __tablename__ = "system_configs"

    key = Column(String(100), primary_key=True)
    value = Column(Text, nullable=False)  # values are stored as text and parsed by readers
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    def __repr__(self):
        return f"<SystemConfig(key='{self.key}', value='{self.value}')>"
|
||||
|
||||
|
||||
class Exam(Base):
    """An exam: a question-bank container plus the user's quiz progress."""

    __tablename__ = "exams"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    title = Column(String(200), nullable=False)
    # Document-processing state; indexed for status-filtered listings.
    status = Column(Enum(ExamStatus), default=ExamStatus.PENDING, nullable=False, index=True)
    # Resume point for "continue where I left off" quizzing.
    current_index = Column(Integer, default=0, nullable=False)
    total_questions = Column(Integer, default=0, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())

    user = relationship("User", back_populates="exams")
    questions = relationship("Question", back_populates="exam", cascade="all, delete-orphan")

    # Composite index for the common "this user's exams in state X" query.
    __table_args__ = (
        Index('ix_exams_user_status', 'user_id', 'status'),
    )

    def __repr__(self):
        return f"<Exam(id={self.id}, title='{self.title}', status={self.status}, questions={self.total_questions})>"
|
||||
|
||||
|
||||
class Question(Base):
    """A single question extracted from an uploaded document."""

    __tablename__ = "questions"

    id = Column(Integer, primary_key=True, index=True)
    exam_id = Column(Integer, ForeignKey("exams.id", ondelete="CASCADE"), nullable=False)
    content = Column(Text, nullable=False)
    type = Column(Enum(QuestionType), nullable=False)
    # Choice questions keep options as a JSON list, e.g.
    # ["A. Option1", "B. Option2", ...]; null for judge/short types.
    options = Column(JSON, nullable=True)
    answer = Column(Text, nullable=False)
    analysis = Column(Text, nullable=True)
    # MD5 of the content (32 hex chars), used for per-exam deduplication.
    content_hash = Column(String(32), nullable=False, index=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)

    exam = relationship("Exam", back_populates="questions")
    mistakes = relationship("UserMistake", back_populates="question", cascade="all, delete-orphan")

    # Composite index speeds up duplicate-hash lookups within one exam.
    __table_args__ = (
        Index('ix_questions_exam_hash', 'exam_id', 'content_hash'),
    )

    def __repr__(self):
        return f"<Question(id={self.id}, type={self.type}, hash={self.content_hash[:8]}...)>"
|
||||
|
||||
|
||||
class UserMistake(Base):
    """Mistake-book entry: links a user to a question they answered wrong."""

    __tablename__ = "user_mistakes"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
    question_id = Column(Integer, ForeignKey("questions.id", ondelete="CASCADE"), nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)

    user = relationship("User", back_populates="mistakes")
    question = relationship("Question", back_populates="mistakes")

    # Unique pair: a question appears at most once per user's mistake book.
    __table_args__ = (
        Index('ix_user_mistakes_unique', 'user_id', 'question_id', unique=True),
    )

    def __repr__(self):
        return f"<UserMistake(user_id={self.user_id}, question_id={self.question_id})>"
|
||||
18
backend/requirements.txt
Normal file
18
backend/requirements.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
sqlalchemy==2.0.25
|
||||
asyncpg==0.29.0
|
||||
alembic==1.13.1
|
||||
pydantic==2.5.3
|
||||
pydantic-settings==2.1.0
|
||||
python-dotenv==1.0.0
|
||||
python-multipart==0.0.6
|
||||
passlib[bcrypt]==1.7.4
|
||||
python-jose[cryptography]==3.3.0
|
||||
aiofiles==23.2.1
|
||||
httpx==0.26.0
|
||||
openai==1.10.0
|
||||
anthropic==0.8.1
|
||||
python-docx==1.1.0
|
||||
PyPDF2==3.0.1
|
||||
openpyxl==3.1.2
|
||||
6
backend/routers/__init__.py
Normal file
6
backend/routers/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Routers package
|
||||
"""
|
||||
from . import auth, exam, question, mistake, admin
|
||||
|
||||
__all__ = ["auth", "exam", "question", "mistake", "admin"]
|
||||
63
backend/routers/admin.py
Normal file
63
backend/routers/admin.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Admin Router
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from database import get_db
|
||||
from models import User, SystemConfig
|
||||
from schemas import SystemConfigUpdate, SystemConfigResponse
|
||||
from services.auth_service import get_current_admin_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/config", response_model=SystemConfigResponse)
async def get_system_config(
    current_admin: User = Depends(get_current_admin_user),
    db: AsyncSession = Depends(get_db)
):
    """Return the current system configuration (admin only)."""
    # Load every stored key/value pair in one query.
    rows = (await db.execute(select(SystemConfig))).scalars().all()
    stored = {row.key: row.value for row in rows}

    # Values live as text in the DB; coerce to the response types here.
    return {
        "allow_registration": stored.get("allow_registration", "true").lower() == "true",
        "max_upload_size_mb": int(stored.get("max_upload_size_mb", "10")),
        "max_daily_uploads": int(stored.get("max_daily_uploads", "20")),
        "ai_provider": stored.get("ai_provider", "openai")
    }
|
||||
|
||||
|
||||
@router.put("/config", response_model=SystemConfigResponse)
async def update_system_config(
    config_update: SystemConfigUpdate,
    current_admin: User = Depends(get_current_admin_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Update system configuration (admin only).

    Upserts each supplied field into the SystemConfig key-value table and
    returns the resulting full configuration.
    """
    # Pydantic v2 (pinned in requirements.txt): .dict() is deprecated —
    # use model_dump(). exclude_unset keeps untouched fields out of the update.
    update_data = config_update.model_dump(exclude_unset=True)

    for key, value in update_data.items():
        # Booleans are persisted as the lowercase strings "true"/"false",
        # matching what get_system_config parses back.
        stored_value = str(value).lower() if isinstance(value, bool) else str(value)

        result = await db.execute(
            select(SystemConfig).where(SystemConfig.key == key)
        )
        config = result.scalar_one_or_none()

        if config:
            config.value = stored_value
        else:
            db.add(SystemConfig(key=key, value=stored_value))

    await db.commit()

    # Return the freshly persisted configuration.
    return await get_system_config(current_admin, db)
|
||||
130
backend/routers/auth.py
Normal file
130
backend/routers/auth.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Authentication Router
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from datetime import timedelta
|
||||
|
||||
from database import get_db
|
||||
from models import User, SystemConfig
|
||||
from schemas import UserCreate, UserLogin, Token, UserResponse
|
||||
from utils import hash_password, verify_password, create_access_token
|
||||
from services.auth_service import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
    user_data: UserCreate,
    db: AsyncSession = Depends(get_db)
):
    """Register a new (non-admin) user account."""

    # Guard 1: registration can be switched off via system config.
    allow_cfg = (
        await db.execute(
            select(SystemConfig).where(SystemConfig.key == "allow_registration")
        )
    ).scalar_one_or_none()
    if allow_cfg and allow_cfg.value.lower() == "false":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Registration is currently disabled"
        )

    # Guard 2: usernames must be unique.
    existing_user = (
        await db.execute(
            select(User).where(User.username == user_data.username)
        )
    ).scalar_one_or_none()
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered"
        )

    # Persist the account with a hashed password; plaintext never stored.
    new_user = User(
        username=user_data.username,
        hashed_password=hash_password(user_data.password),
        is_admin=False
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)

    return new_user
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
async def login(
    user_data: UserLogin,
    db: AsyncSession = Depends(get_db)
):
    """
    Authenticate a user and issue a bearer access token.

    Raises:
        HTTPException 401 for unknown username or wrong password (the same
        message for both, to avoid user enumeration).
    """
    # Look up the account by username.
    result = await db.execute(
        select(User).where(User.username == user_data.username)
    )
    user = result.scalar_one_or_none()

    if not user or not verify_password(user_data.password, user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # RFC 7519 requires the "sub" claim to be a string (StringOrURI), and
    # python-jose rejects non-string subjects when decoding — so cast the
    # integer user id to str before embedding it in the token.
    access_token = create_access_token(
        data={"sub": str(user.id)}
    )

    return {
        "access_token": access_token,
        "token_type": "bearer"
    }
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
    current_user: User = Depends(get_current_user)
):
    """Return the authenticated user's profile."""
    # Authentication already happened in the dependency; just echo the user.
    return current_user
|
||||
|
||||
|
||||
@router.post("/change-password")
async def change_password(
    old_password: str,
    new_password: str,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Change the authenticated user's password."""

    # The caller must prove knowledge of the current password.
    if not verify_password(old_password, current_user.hashed_password):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Incorrect current password"
        )

    # Minimal strength rule: at least 6 characters.
    if len(new_password) < 6:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="New password must be at least 6 characters"
        )

    # Store only the hash of the new password.
    current_user.hashed_password = hash_password(new_password)
    await db.commit()

    return {"message": "Password changed successfully"}
|
||||
411
backend/routers/exam.py
Normal file
411
backend/routers/exam.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""
|
||||
Exam Router - Handles exam creation, file upload, and deduplication
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File, Form, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, func, and_
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import aiofiles
|
||||
|
||||
from database import get_db
|
||||
from models import User, Exam, Question, ExamStatus, SystemConfig
|
||||
from schemas import (
|
||||
ExamCreate, ExamResponse, ExamListResponse,
|
||||
ExamUploadResponse, ParseResult, QuizProgressUpdate
|
||||
)
|
||||
from services.auth_service import get_current_user
|
||||
from services.document_parser import document_parser
|
||||
from services.llm_service import llm_service
|
||||
from utils import is_allowed_file, calculate_content_hash
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def check_upload_limits(user_id: int, file_size: int, db: AsyncSession):
    """
    Enforce per-file size and per-day upload-count limits for a user.

    Args:
        user_id: Owner of the upload.
        file_size: Size of the incoming file in bytes.
        db: Database session.

    Raises:
        HTTPException 413 if the file exceeds the configured size limit.
        HTTPException 429 if today's upload count reached the daily cap.
    """
    from datetime import timezone  # for an aware "start of today" timestamp

    # Max file size (MB) from system config; default 10.
    result = await db.execute(
        select(SystemConfig).where(SystemConfig.key == "max_upload_size_mb")
    )
    config = result.scalar_one_or_none()
    max_size_mb = int(config.value) if config else 10

    if file_size > max_size_mb * 1024 * 1024:
        raise HTTPException(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            detail=f"File size exceeds limit of {max_size_mb}MB"
        )

    # Max uploads per day from system config; default 20.
    result = await db.execute(
        select(SystemConfig).where(SystemConfig.key == "max_daily_uploads")
    )
    config = result.scalar_one_or_none()
    max_daily = int(config.value) if config else 20

    # Count exams created since UTC midnight. Use a timezone-aware value:
    # Exam.created_at is DateTime(timezone=True), and datetime.utcnow() is
    # naive (and deprecated since Python 3.12), so comparing naive values
    # against a timestamptz column is unreliable.
    today_start = datetime.now(timezone.utc).replace(
        hour=0, minute=0, second=0, microsecond=0
    )
    result = await db.execute(
        select(func.count(Exam.id)).where(
            and_(
                Exam.user_id == user_id,
                Exam.created_at >= today_start
            )
        )
    )
    upload_count = result.scalar()

    if upload_count >= max_daily:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Daily upload limit of {max_daily} reached"
        )
|
||||
|
||||
|
||||
async def process_questions_with_dedup(
    exam_id: int,
    questions_data: List[dict],
    db: AsyncSession
) -> ParseResult:
    """
    Insert parsed questions into an exam, skipping duplicates.

    Duplicates are detected by content_hash, scoped to the exam: hashes
    already stored for the exam — or seen earlier in this same batch —
    are dropped.

    Args:
        exam_id: Target exam ID.
        questions_data: Question dicts produced by the LLM parser.
        db: Database session.

    Returns:
        ParseResult with parse/duplicate/insert counts and a summary message.
    """
    total_parsed = len(questions_data)
    duplicates_removed = 0
    new_added = 0

    # Seed the seen-set with hashes already persisted for this exam.
    result = await db.execute(
        select(Question.content_hash).where(Question.exam_id == exam_id)
    )
    seen_hashes = {row[0] for row in result.all()}

    for q_data in questions_data:
        content_hash = q_data.get("content_hash")

        if content_hash in seen_hashes:
            duplicates_removed += 1
            continue

        db.add(Question(
            exam_id=exam_id,
            content=q_data["content"],
            type=q_data["type"],
            options=q_data.get("options"),
            answer=q_data["answer"],
            analysis=q_data.get("analysis"),
            content_hash=content_hash,
        ))
        # Track in-batch hashes too, so repeats within one upload dedupe.
        seen_hashes.add(content_hash)
        new_added += 1

    await db.commit()

    return ParseResult(
        total_parsed=total_parsed,
        duplicates_removed=duplicates_removed,
        new_added=new_added,
        message=f"Parsed {total_parsed} questions, removed {duplicates_removed} duplicates, added {new_added} new questions"
    )
|
||||
|
||||
|
||||
async def async_parse_and_save(
    exam_id: int,
    file_content: bytes,
    filename: str,
    db_url: str
):
    """
    Background task to parse document and save questions with deduplication.

    Moves the exam through PENDING -> PROCESSING -> READY, or marks it FAILED
    on any error.

    Args:
        exam_id: Exam the extracted questions belong to.
        file_content: Raw bytes of the uploaded document.
        filename: Original filename (used to select the parser).
        db_url: Currently unused; kept for call-site compatibility.
    """
    from database import AsyncSessionLocal
    from sqlalchemy import select

    async with AsyncSessionLocal() as db:
        try:
            # Update exam status to processing
            result = await db.execute(select(Exam).where(Exam.id == exam_id))
            exam = result.scalar_one()
            exam.status = ExamStatus.PROCESSING
            await db.commit()

            # Parse document (fix: log the actual filename, not a placeholder)
            print(f"[Exam {exam_id}] Parsing document: {filename}")
            text_content = await document_parser.parse_file(file_content, filename)

            if not text_content or len(text_content.strip()) < 10:
                raise Exception("Document appears to be empty or too short")

            # Parse questions using LLM
            print(f"[Exam {exam_id}] Calling LLM to extract questions...")
            questions_data = await llm_service.parse_document(text_content)

            if not questions_data:
                raise Exception("No questions found in document")

            # Process questions with deduplication
            print(f"[Exam {exam_id}] Processing questions with deduplication...")
            parse_result = await process_questions_with_dedup(exam_id, questions_data, db)

            # Update exam status and total questions
            result = await db.execute(select(Exam).where(Exam.id == exam_id))
            exam = result.scalar_one()

            # Get updated question count
            result = await db.execute(
                select(func.count(Question.id)).where(Question.exam_id == exam_id)
            )
            total_questions = result.scalar()

            exam.status = ExamStatus.READY
            exam.total_questions = total_questions
            await db.commit()

            print(f"[Exam {exam_id}] ✅ {parse_result.message}")

        except Exception as e:
            print(f"[Exam {exam_id}] ❌ Error: {str(e)}")

            # The session may be in a failed transaction state (e.g. the
            # error came from the DB itself) — roll back before trying to
            # persist the FAILED status.
            await db.rollback()
            result = await db.execute(select(Exam).where(Exam.id == exam_id))
            exam = result.scalar_one()
            exam.status = ExamStatus.FAILED
            await db.commit()
|
||||
|
||||
|
||||
@router.post("/create", response_model=ExamUploadResponse, status_code=status.HTTP_201_CREATED)
async def create_exam_with_upload(
    background_tasks: BackgroundTasks,
    title: str = Form(...),
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Create a new exam and upload the first document.
    Document will be parsed asynchronously in background.
    """
    # Reject unsupported file types up front.
    if not (file.filename and is_allowed_file(file.filename)):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid file type. Allowed: txt, pdf, doc, docx, xlsx, xls"
        )

    file_content = await file.read()

    # Enforce per-user size/daily-upload limits before touching the database.
    await check_upload_limits(current_user.id, len(file_content), db)

    # Persist the exam shell; background parsing fills it in later.
    new_exam = Exam(user_id=current_user.id, title=title, status=ExamStatus.PENDING)
    db.add(new_exam)
    await db.commit()
    await db.refresh(new_exam)

    # Parse the document after the response has been sent.
    background_tasks.add_task(
        async_parse_and_save,
        new_exam.id,
        file_content,
        file.filename,
        os.getenv("DATABASE_URL")
    )

    return ExamUploadResponse(
        exam_id=new_exam.id,
        title=new_exam.title,
        status=new_exam.status.value,
        message="Exam created. Document is being processed in background."
    )
|
||||
|
||||
|
||||
@router.post("/{exam_id}/append", response_model=ExamUploadResponse)
async def append_document_to_exam(
    exam_id: int,
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Append a new document to an existing exam.
    Questions will be parsed and deduplicated asynchronously.
    """
    # The exam must exist and belong to the caller.
    rows = await db.execute(
        select(Exam).where(
            and_(Exam.id == exam_id, Exam.user_id == current_user.id)
        )
    )
    exam = rows.scalar_one_or_none()
    if exam is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Exam not found"
        )

    # Only one parse job per exam at a time.
    if exam.status == ExamStatus.PROCESSING:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Exam is currently being processed. Please wait."
        )

    # Reject unsupported file types.
    if not (file.filename and is_allowed_file(file.filename)):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid file type. Allowed: txt, pdf, doc, docx, xlsx, xls"
        )

    file_content = await file.read()

    # Enforce per-user size/daily-upload limits.
    await check_upload_limits(current_user.id, len(file_content), db)

    # Parse in the background; deduplication happens during insertion.
    background_tasks.add_task(
        async_parse_and_save,
        exam.id,
        file_content,
        file.filename,
        os.getenv("DATABASE_URL")
    )

    return ExamUploadResponse(
        exam_id=exam.id,
        title=exam.title,
        status=ExamStatus.PROCESSING.value,
        message=f"Document '{file.filename}' is being processed. Duplicates will be automatically removed."
    )
|
||||
|
||||
|
||||
@router.get("/", response_model=ExamListResponse)
async def get_user_exams(
    skip: int = 0,
    limit: int = 20,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Get all exams for current user"""
    owner_filter = Exam.user_id == current_user.id

    # Total number of exams owned by the caller.
    total = (
        await db.execute(select(func.count(Exam.id)).where(owner_filter))
    ).scalar()

    # Requested page, newest first.
    rows = await db.execute(
        select(Exam)
        .where(owner_filter)
        .order_by(Exam.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    return ExamListResponse(exams=rows.scalars().all(), total=total)
|
||||
|
||||
|
||||
@router.get("/{exam_id}", response_model=ExamResponse)
async def get_exam_detail(
    exam_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Get exam details"""
    rows = await db.execute(
        select(Exam).where(
            and_(Exam.id == exam_id, Exam.user_id == current_user.id)
        )
    )
    exam = rows.scalar_one_or_none()

    if exam is None:
        # Other users' exams are indistinguishable from missing ones (404).
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Exam not found"
        )

    return exam
|
||||
|
||||
|
||||
@router.delete("/{exam_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_exam(
    exam_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Delete an exam and all its questions"""
    rows = await db.execute(
        select(Exam).where(
            and_(Exam.id == exam_id, Exam.user_id == current_user.id)
        )
    )
    exam = rows.scalar_one_or_none()

    if exam is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Exam not found"
        )

    # Questions are presumably removed via a relationship cascade on the
    # model — confirm against the Exam/Question mapping.
    await db.delete(exam)
    await db.commit()
|
||||
|
||||
|
||||
@router.put("/{exam_id}/progress", response_model=ExamResponse)
async def update_quiz_progress(
    exam_id: int,
    progress: QuizProgressUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Update quiz progress (current_index)"""
    rows = await db.execute(
        select(Exam).where(
            and_(Exam.id == exam_id, Exam.user_id == current_user.id)
        )
    )
    exam = rows.scalar_one_or_none()

    if exam is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Exam not found"
        )

    # Persist the new resume position.
    exam.current_index = progress.current_index
    await db.commit()
    await db.refresh(exam)

    return exam
|
||||
192
backend/routers/mistake.py
Normal file
192
backend/routers/mistake.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Mistake Router - Handles user mistake book (错题本)
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database import get_db
|
||||
from models import User, Question, UserMistake, Exam
|
||||
from schemas import MistakeAdd, MistakeResponse, MistakeListResponse
|
||||
from services.auth_service import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=MistakeListResponse)
async def get_user_mistakes(
    skip: int = 0,
    limit: int = 50,
    exam_id: int = None,  # Optional filter by exam
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Get user's mistake book with optional exam filter"""
    base = (
        select(UserMistake)
        .options(selectinload(UserMistake.question))
        .where(UserMistake.user_id == current_user.id)
        .order_by(UserMistake.created_at.desc())
    )
    count_stmt = select(func.count(UserMistake.id)).where(
        UserMistake.user_id == current_user.id
    )

    # Narrow both queries to a single exam when requested.
    if exam_id is not None:
        base = base.join(Question).where(Question.exam_id == exam_id)
        count_stmt = count_stmt.join(Question).where(Question.exam_id == exam_id)

    total = (await db.execute(count_stmt)).scalar()
    records = (await db.execute(base.offset(skip).limit(limit))).scalars().all()

    items = [
        MistakeResponse(
            id=rec.id,
            user_id=rec.user_id,
            question_id=rec.question_id,
            question=rec.question,
            created_at=rec.created_at
        )
        for rec in records
    ]
    return MistakeListResponse(mistakes=items, total=total)
|
||||
|
||||
|
||||
@router.post("/add", response_model=MistakeResponse, status_code=status.HTTP_201_CREATED)
async def add_to_mistakes(
    mistake_data: MistakeAdd,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Manually add a question to mistake book"""
    # The question must belong to one of the caller's exams.
    rows = await db.execute(
        select(Question)
        .join(Exam)
        .where(
            and_(
                Question.id == mistake_data.question_id,
                Exam.user_id == current_user.id
            )
        )
    )
    if rows.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Question not found or you don't have access"
        )

    # Reject duplicates — each question appears at most once per user.
    rows = await db.execute(
        select(UserMistake).where(
            and_(
                UserMistake.user_id == current_user.id,
                UserMistake.question_id == mistake_data.question_id
            )
        )
    )
    if rows.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Question already in mistake book"
        )

    record = UserMistake(
        user_id=current_user.id,
        question_id=mistake_data.question_id
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)

    # Re-select with the question eagerly loaded for the response payload.
    rows = await db.execute(
        select(UserMistake)
        .options(selectinload(UserMistake.question))
        .where(UserMistake.id == record.id)
    )
    record = rows.scalar_one()

    return MistakeResponse(
        id=record.id,
        user_id=record.user_id,
        question_id=record.question_id,
        question=record.question,
        created_at=record.created_at
    )
|
||||
|
||||
|
||||
@router.delete("/{mistake_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_from_mistakes(
    mistake_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Remove a question from mistake book"""
    # The record must exist and belong to the caller.
    rows = await db.execute(
        select(UserMistake).where(
            and_(
                UserMistake.id == mistake_id,
                UserMistake.user_id == current_user.id
            )
        )
    )
    record = rows.scalar_one_or_none()

    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Mistake record not found"
        )

    await db.delete(record)
    await db.commit()
|
||||
|
||||
|
||||
@router.delete("/question/{question_id}", status_code=status.HTTP_204_NO_CONTENT)
async def remove_question_from_mistakes(
    question_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Remove a question from mistake book by question ID"""
    # Look the record up by (user, question) rather than by mistake id.
    rows = await db.execute(
        select(UserMistake).where(
            and_(
                UserMistake.question_id == question_id,
                UserMistake.user_id == current_user.id
            )
        )
    )
    record = rows.scalar_one_or_none()

    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Question not found in mistake book"
        )

    await db.delete(record)
    await db.commit()
|
||||
228
backend/routers/question.py
Normal file
228
backend/routers/question.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Question Router - Handles quiz playing and answer checking
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, func
|
||||
from typing import List, Optional
|
||||
|
||||
from database import get_db
|
||||
from models import User, Exam, Question, UserMistake, ExamStatus, QuestionType
|
||||
from schemas import (
|
||||
QuestionResponse, QuestionListResponse,
|
||||
AnswerSubmit, AnswerCheckResponse
|
||||
)
|
||||
from services.auth_service import get_current_user
|
||||
from services.llm_service import llm_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/exam/{exam_id}/questions", response_model=QuestionListResponse)
async def get_exam_questions(
    exam_id: int,
    skip: int = 0,
    limit: int = 50,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Get all questions for an exam"""
    # Only the exam owner may list its questions.
    rows = await db.execute(
        select(Exam).where(
            and_(Exam.id == exam_id, Exam.user_id == current_user.id)
        )
    )
    if rows.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Exam not found"
        )

    # Total question count for pagination.
    total = (
        await db.execute(
            select(func.count(Question.id)).where(Question.exam_id == exam_id)
        )
    ).scalar()

    # Requested page in stable (id) order.
    rows = await db.execute(
        select(Question)
        .where(Question.exam_id == exam_id)
        .order_by(Question.id)
        .offset(skip)
        .limit(limit)
    )
    return QuestionListResponse(questions=rows.scalars().all(), total=total)
|
||||
|
||||
|
||||
@router.get("/exam/{exam_id}/current", response_model=QuestionResponse)
async def get_current_question(
    exam_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Get the current question based on exam's current_index"""
    rows = await db.execute(
        select(Exam).where(
            and_(Exam.id == exam_id, Exam.user_id == current_user.id)
        )
    )
    exam = rows.scalar_one_or_none()

    if exam is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Exam not found"
        )

    # Questions are only served once background parsing has finished.
    if exam.status != ExamStatus.READY:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Exam is not ready. Status: {exam.status.value}"
        )

    # Fetch the current_index-th question in stable (id) order.
    rows = await db.execute(
        select(Question)
        .where(Question.exam_id == exam_id)
        .order_by(Question.id)
        .offset(exam.current_index)
        .limit(1)
    )
    question = rows.scalar_one_or_none()

    if question is None:
        # current_index points past the last question: the exam is done.
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No more questions available. You've completed this exam!"
        )

    return question
|
||||
|
||||
|
||||
@router.get("/{question_id}", response_model=QuestionResponse)
async def get_question_by_id(
    question_id: int,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """Get a specific question by ID"""
    # Access is granted through ownership of the parent exam.
    rows = await db.execute(
        select(Question)
        .join(Exam)
        .where(
            and_(
                Question.id == question_id,
                Exam.user_id == current_user.id
            )
        )
    )
    question = rows.scalar_one_or_none()

    if question is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Question not found"
        )

    return question
|
||||
|
||||
|
||||
@router.post("/check", response_model=AnswerCheckResponse)
async def check_answer(
    answer_data: AnswerSubmit,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db)
):
    """
    Check user's answer and return result.
    For short answers, use AI to grade.
    Automatically add wrong answers to mistake book.
    """
    # Fetch the question, scoped to exams owned by the caller.
    rows = await db.execute(
        select(Question)
        .join(Exam)
        .where(
            and_(
                Question.id == answer_data.question_id,
                Exam.user_id == current_user.id
            )
        )
    )
    question = rows.scalar_one_or_none()

    if question is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Question not found"
        )

    submitted = answer_data.user_answer.strip()
    expected = question.answer.strip()
    ai_score = None
    ai_feedback = None

    if question.type == QuestionType.SHORT:
        # Free-text answers are graded by the LLM.
        grading = await llm_service.grade_short_answer(
            question.content,
            expected,
            submitted
        )
        ai_score = grading["score"]
        ai_feedback = grading["feedback"]
        is_correct = ai_score >= 0.7  # Consider 70% as correct
    elif question.type == QuestionType.MULTIPLE:
        # Order- and case-insensitive comparison of the selected letters.
        def canonical(ans: str) -> str:
            return ''.join(sorted(ans.upper().replace(' ', '')))
        is_correct = canonical(submitted) == canonical(expected)
    else:
        # Single choice / judge: simple case-insensitive match.
        is_correct = submitted.upper() == expected.upper()

    if not is_correct:
        # Record the miss in the mistake book, once per question per user.
        rows = await db.execute(
            select(UserMistake).where(
                and_(
                    UserMistake.user_id == current_user.id,
                    UserMistake.question_id == question.id
                )
            )
        )
        if rows.scalar_one_or_none() is None:
            db.add(UserMistake(user_id=current_user.id, question_id=question.id))
            await db.commit()

    return AnswerCheckResponse(
        correct=is_correct,
        user_answer=submitted,
        correct_answer=expected,
        analysis=question.analysis,
        ai_score=ai_score,
        ai_feedback=ai_feedback
    )
|
||||
160
backend/schemas.py
Normal file
160
backend/schemas.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Pydantic Schemas for Request/Response Validation
|
||||
"""
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from models import ExamStatus, QuestionType
|
||||
|
||||
|
||||
# ============ Auth Schemas ============
|
||||
class UserCreate(BaseModel):
    """Registration payload: username plus plaintext password."""
    username: str = Field(..., min_length=3, max_length=50)
    password: str = Field(..., min_length=6)

    @validator('username')
    def username_alphanumeric(cls, v):
        # Allow letters/digits plus '_' and '-': strip those two characters
        # first, then require the remainder to be alphanumeric.
        if not v.replace('_', '').replace('-', '').isalnum():
            raise ValueError('Username must be alphanumeric (allows _ and -)')
        return v
|
||||
|
||||
|
||||
class UserLogin(BaseModel):
    """Login payload: credentials as submitted by the client."""
    username: str
    password: str
|
||||
|
||||
|
||||
class Token(BaseModel):
    """JWT access token returned after a successful login."""
    access_token: str
    token_type: str = "bearer"  # fixed scheme for the Authorization header
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
    """Public view of a user account (no password hash)."""
    id: int
    username: str
    is_admin: bool
    created_at: datetime

    class Config:
        # Allow construction directly from ORM objects.
        from_attributes = True
|
||||
|
||||
|
||||
# ============ System Config Schemas ============
|
||||
class SystemConfigUpdate(BaseModel):
    """Partial update of system settings; None fields are left unchanged."""
    allow_registration: Optional[bool] = None
    max_upload_size_mb: Optional[int] = None
    max_daily_uploads: Optional[int] = None
    ai_provider: Optional[str] = None
|
||||
|
||||
|
||||
class SystemConfigResponse(BaseModel):
    """Current system settings snapshot."""
    allow_registration: bool
    max_upload_size_mb: int
    max_daily_uploads: int
    ai_provider: str
|
||||
|
||||
|
||||
# ============ Exam Schemas ============
|
||||
class ExamCreate(BaseModel):
    """Payload for creating a new exam (question bank)."""
    title: str = Field(..., min_length=1, max_length=200)
|
||||
|
||||
|
||||
class ExamResponse(BaseModel):
    """Full exam record as returned to its owner."""
    id: int
    user_id: int
    title: str
    status: ExamStatus        # PENDING / PROCESSING / READY / FAILED
    current_index: int        # resume position within the question list
    total_questions: int
    created_at: datetime
    updated_at: datetime

    class Config:
        # Allow construction directly from ORM objects.
        from_attributes = True
|
||||
|
||||
|
||||
class ExamListResponse(BaseModel):
    """Paginated list of exams; total is the unpaginated count."""
    exams: List[ExamResponse]
    total: int
|
||||
|
||||
|
||||
class ExamUploadResponse(BaseModel):
    """Acknowledgement returned immediately after an upload is accepted."""
    exam_id: int
    title: str
    status: str      # exam status as a plain string value
    message: str     # human-readable progress note
|
||||
|
||||
|
||||
class ParseResult(BaseModel):
    """Result from file parsing: dedup statistics plus a summary message."""
    total_parsed: int          # questions extracted from the document
    duplicates_removed: int    # skipped because their hash already existed
    new_added: int             # actually inserted
    message: str
|
||||
|
||||
|
||||
# ============ Question Schemas ============
|
||||
class QuestionBase(BaseModel):
    """Common question fields shared by create/response schemas."""
    content: str
    type: QuestionType
    options: Optional[List[str]] = None   # only for choice questions
    answer: str
    analysis: Optional[str] = None        # optional explanation text
|
||||
|
||||
|
||||
class QuestionCreate(QuestionBase):
    """Question payload bound to a specific exam."""
    exam_id: int
|
||||
|
||||
|
||||
class QuestionResponse(QuestionBase):
    """Stored question as returned by the API."""
    id: int
    exam_id: int
    created_at: datetime

    class Config:
        # Allow construction directly from ORM objects.
        from_attributes = True
|
||||
|
||||
|
||||
class QuestionListResponse(BaseModel):
    """Paginated list of questions; total is the unpaginated count."""
    questions: List[QuestionResponse]
    total: int
|
||||
|
||||
|
||||
# ============ Quiz Schemas ============
|
||||
class AnswerSubmit(BaseModel):
    """A user's answer to a single question."""
    question_id: int
    user_answer: str
|
||||
|
||||
|
||||
class AnswerCheckResponse(BaseModel):
    """Grading outcome for a submitted answer."""
    correct: bool
    user_answer: str
    correct_answer: str
    analysis: Optional[str] = None
    ai_score: Optional[float] = None     # For short answer questions
    ai_feedback: Optional[str] = None    # For short answer questions
|
||||
|
||||
|
||||
class QuizProgressUpdate(BaseModel):
    """New resume position within an exam's question list."""
    current_index: int
|
||||
|
||||
|
||||
# ============ Mistake Schemas ============
|
||||
class MistakeAdd(BaseModel):
    """Request to add a question to the caller's mistake book."""
    question_id: int
|
||||
|
||||
|
||||
class MistakeResponse(BaseModel):
    """One mistake-book entry with its question embedded."""
    id: int
    user_id: int
    question_id: int
    question: QuestionResponse
    created_at: datetime

    class Config:
        # Allow construction directly from ORM objects.
        from_attributes = True
|
||||
|
||||
|
||||
class MistakeListResponse(BaseModel):
    """Paginated list of mistake-book entries."""
    mistakes: List[MistakeResponse]
    total: int
|
||||
13
backend/services/__init__.py
Normal file
13
backend/services/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Services package
|
||||
"""
|
||||
from .auth_service import get_current_user, get_current_admin_user
|
||||
from .llm_service import llm_service
|
||||
from .document_parser import document_parser
|
||||
|
||||
__all__ = [
|
||||
"get_current_user",
|
||||
"get_current_admin_user",
|
||||
"llm_service",
|
||||
"document_parser"
|
||||
]
|
||||
78
backend/services/auth_service.py
Normal file
78
backend/services/auth_service.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Authentication Service
|
||||
"""
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Optional
|
||||
|
||||
from models import User
|
||||
from database import get_db
|
||||
from utils import decode_access_token
|
||||
|
||||
security = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db)
) -> User:
    """
    Get current authenticated user from JWT token.

    Raises:
        HTTPException: 401 if the token is invalid, carries no subject, or
            the referenced user no longer exists.
    """
    token = credentials.credentials
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Decode token; decode_access_token returns None on any decode failure.
    payload = decode_access_token(token)
    if payload is None:
        raise credentials_exception

    # NOTE(review): JWT "sub" is conventionally a string; this assumes the
    # token was issued with a subject comparable to User.id — confirm
    # against the token-creation code.
    user_id: int = payload.get("sub")
    if user_id is None:
        raise credentials_exception

    # Get user from database
    result = await db.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()

    if user is None:
        raise credentials_exception

    return user
|
||||
|
||||
|
||||
async def get_current_admin_user(
    current_user: User = Depends(get_current_user)
) -> User:
    """
    Get current user and verify admin permissions.

    Raises:
        HTTPException: 403 when the authenticated user is not an admin.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Not enough permissions"
    )
|
||||
|
||||
|
||||
async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
    db: AsyncSession = Depends(get_db)
) -> Optional[User]:
    """
    Get current user if token is provided, otherwise return None.
    Useful for endpoints that work for both authenticated and anonymous users.

    NOTE(review): ``security = HTTPBearer()`` raises on a missing header by
    default, so ``credentials`` is presumably never None here unless the
    scheme is constructed with ``auto_error=False`` — confirm.
    """
    if credentials is None:
        return None

    try:
        return await get_current_user(credentials, db)
    except HTTPException:
        # An invalid/expired token degrades to anonymous access.
        return None
|
||||
121
backend/services/document_parser.py
Normal file
121
backend/services/document_parser.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Document Parser Service
|
||||
Supports: TXT, PDF, DOCX, XLSX
|
||||
"""
|
||||
import io
|
||||
from typing import Optional
|
||||
import PyPDF2
|
||||
from docx import Document
|
||||
import openpyxl
|
||||
|
||||
|
||||
class DocumentParser:
    """Parse various document formats (TXT/PDF/DOCX/XLSX) into plain text."""

    @staticmethod
    async def parse_txt(file_content: bytes) -> str:
        """Parse TXT file: try UTF-8, then GBK, then lossy UTF-8."""
        try:
            return file_content.decode('utf-8')
        except UnicodeDecodeError:
            try:
                return file_content.decode('gbk')
            except UnicodeDecodeError:
                # Narrowed from a bare except: only decoding errors should
                # trigger the lossy fallback (drop undecodable bytes).
                return file_content.decode('utf-8', errors='ignore')

    @staticmethod
    async def parse_pdf(file_content: bytes) -> str:
        """Parse PDF file: concatenate extracted text of all pages."""
        try:
            pdf_file = io.BytesIO(file_content)
            pdf_reader = PyPDF2.PdfReader(pdf_file)

            text_content = []
            for page in pdf_reader.pages:
                text = page.extract_text()
                # Pages with no extractable text (e.g. scans) are skipped.
                if text:
                    text_content.append(text)

            return '\n\n'.join(text_content)
        except Exception as e:
            raise Exception(f"Failed to parse PDF: {str(e)}")

    @staticmethod
    async def parse_docx(file_content: bytes) -> str:
        """Parse DOCX file: paragraphs plus table cells joined with ' | '."""
        try:
            docx_file = io.BytesIO(file_content)
            doc = Document(docx_file)

            text_content = []
            for paragraph in doc.paragraphs:
                if paragraph.text.strip():
                    text_content.append(paragraph.text)

            # Also extract text from tables
            for table in doc.tables:
                for row in table.rows:
                    row_text = ' | '.join(cell.text.strip() for cell in row.cells)
                    if row_text.strip():
                        text_content.append(row_text)

            return '\n\n'.join(text_content)
        except Exception as e:
            raise Exception(f"Failed to parse DOCX: {str(e)}")

    @staticmethod
    async def parse_xlsx(file_content: bytes) -> str:
        """Parse XLSX file: every sheet, rows joined with ' | '."""
        try:
            xlsx_file = io.BytesIO(file_content)
            # data_only=True yields cached cell values instead of formulas.
            workbook = openpyxl.load_workbook(xlsx_file, data_only=True)

            text_content = []
            for sheet_name in workbook.sheetnames:
                sheet = workbook[sheet_name]
                text_content.append(f"=== Sheet: {sheet_name} ===")

                for row in sheet.iter_rows(values_only=True):
                    row_text = ' | '.join(str(cell) if cell is not None else '' for cell in row)
                    # Skip rows that are entirely empty cells.
                    if row_text.strip(' |'):
                        text_content.append(row_text)

            return '\n\n'.join(text_content)
        except Exception as e:
            raise Exception(f"Failed to parse XLSX: {str(e)}")

    @staticmethod
    async def parse_file(file_content: bytes, filename: str) -> str:
        """
        Parse file based on extension.

        Args:
            file_content: File content as bytes
            filename: Original filename

        Returns:
            Extracted text content

        Raises:
            Exception: If file format is unsupported or parsing fails
        """
        extension = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''

        parsers = {
            'txt': DocumentParser.parse_txt,
            'pdf': DocumentParser.parse_pdf,
            'docx': DocumentParser.parse_docx,
            'doc': DocumentParser.parse_docx,   # Try to parse DOC as DOCX
            'xlsx': DocumentParser.parse_xlsx,
            'xls': DocumentParser.parse_xlsx,   # Try to parse XLS as XLSX
        }

        parser = parsers.get(extension)
        if not parser:
            raise Exception(f"Unsupported file format: {extension}")

        return await parser(file_content)
|
||||
|
||||
|
||||
# Module-level singleton: one shared parser instance for the upload endpoints.
document_parser = DocumentParser()
|
||||
216
backend/services/llm_service.py
Normal file
216
backend/services/llm_service.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
LLM Service for AI-powered question parsing and grading
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import AsyncOpenAI
|
||||
from anthropic import AsyncAnthropic
|
||||
import httpx
|
||||
|
||||
from models import QuestionType
|
||||
from utils import calculate_content_hash
|
||||
|
||||
|
||||
class LLMService:
    """Service for interacting with various LLM providers.

    The provider is chosen via the AI_PROVIDER environment variable
    ("openai", "anthropic", or "qwen"); each provider reads its API key,
    base URL, and model name from its own environment variables.
    """

    def __init__(self):
        self.provider = os.getenv("AI_PROVIDER", "openai")

        if self.provider == "openai":
            self.client = AsyncOpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
            )
            self.model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

        elif self.provider == "anthropic":
            self.client = AsyncAnthropic(
                api_key=os.getenv("ANTHROPIC_API_KEY")
            )
            self.model = os.getenv("ANTHROPIC_MODEL", "claude-3-haiku-20240307")

        elif self.provider == "qwen":
            # Qwen exposes an OpenAI-compatible endpoint, so reuse AsyncOpenAI.
            self.client = AsyncOpenAI(
                api_key=os.getenv("QWEN_API_KEY"),
                base_url=os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
            )
            self.model = os.getenv("QWEN_MODEL", "qwen-plus")

        else:
            raise ValueError(f"Unsupported AI provider: {self.provider}")

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove a surrounding markdown code fence (``` / ```json) from an LLM reply."""
        text = text.strip()
        if text.startswith("```json"):
            text = text[7:]
        if text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    async def _complete(self, user_prompt: str, system_prompt: str,
                        max_tokens: int, temperature: float) -> str:
        """Send one prompt to the configured provider and return the raw reply text.

        Args:
            user_prompt: The user-role message body.
            system_prompt: System message (ignored by the Anthropic path,
                matching the original behavior).
            max_tokens: Completion cap, used by the Anthropic API only.
            temperature: Sampling temperature, used by the OpenAI-style API only.
        """
        if self.provider == "anthropic":
            response = await self.client.messages.create(
                model=self.model,
                max_tokens=max_tokens,
                messages=[
                    {"role": "user", "content": user_prompt}
                ]
            )
            return response.content[0].text

        # OpenAI or Qwen (OpenAI-compatible).
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=temperature,
        )
        return response.choices[0].message.content

    async def parse_document(self, content: str) -> List[Dict[str, Any]]:
        """
        Parse document content and extract questions.

        Args:
            content: Plain text extracted from an uploaded document.

        Returns:
            A list of question dictionaries:
            [
                {
                    "content": "Question text",
                    "type": "single/multiple/judge/short",
                    "options": ["A. Option1", "B. Option2", ...],  # choice questions
                    "answer": "Correct answer",
                    "analysis": "Explanation",
                    "content_hash": "md5 of normalized content"
                },
                ...
            ]

        Raises:
            Exception: If the LLM call fails or its reply is not valid JSON.
        """
        # BUGFIX: the instruction template contains literal JSON braces, so the
        # document text is inserted by concatenation. The original used
        # prompt.format(content=...), which treats every "{" in the JSON
        # examples as a replacement field and raises before any request is sent.
        prompt = (
            """You are a professional question parser. Parse the given document and extract all questions.

For each question, identify:
1. Question content (the question text)
2. Question type: single (单选), multiple (多选), judge (判断), short (简答)
3. Options (for choice questions only, format: ["A. Option1", "B. Option2", ...])
4. Correct answer
5. Analysis/Explanation (if available)

Return ONLY a JSON array of questions, with no additional text:
[
  {
    "content": "question text",
    "type": "single",
    "options": ["A. Option1", "B. Option2", "C. Option3", "D. Option4"],
    "answer": "A",
    "analysis": "explanation"
  },
  ...
]

Document content:
---
"""
            + content
            + """
---

IMPORTANT: Return ONLY the JSON array, no markdown code blocks or explanations."""
        )

        try:
            raw = await self._complete(
                user_prompt=prompt,
                system_prompt="You are a professional question parser. Return only JSON.",
                max_tokens=4096,
                temperature=0.3,
            )

            questions = json.loads(self._strip_code_fences(raw))

            # Attach a normalized content hash to each question for deduplication.
            for q in questions:
                q["content_hash"] = calculate_content_hash(q["content"])

            return questions

        except Exception as e:
            print(f"Error parsing document: {e}")
            raise Exception(f"Failed to parse document: {str(e)}") from e

    async def grade_short_answer(
        self,
        question: str,
        correct_answer: str,
        user_answer: str
    ) -> Dict[str, Any]:
        """
        Grade a short answer question using AI.

        Args:
            question: The question text.
            correct_answer: The reference answer.
            user_answer: The student's answer to grade.

        Returns:
            {
                "score": 0.0-1.0,
                "feedback": "Detailed feedback"
            }
            On any error a zero score with a generic message is returned
            instead of raising (best-effort grading).
        """
        prompt = f"""Grade the following short answer question.

Question: {question}

Standard Answer: {correct_answer}

Student Answer: {user_answer}

Provide a score from 0.0 to 1.0 (where 1.0 is perfect) and detailed feedback.

Return ONLY a JSON object:
{{
    "score": 0.85,
    "feedback": "Your detailed feedback here"
}}

Be fair but strict. Consider:
1. Correctness of key points
2. Completeness of answer
3. Clarity of expression

Return ONLY the JSON object, no markdown or explanations."""

        try:
            raw = await self._complete(
                user_prompt=prompt,
                system_prompt="You are a fair and strict grader. Return only JSON.",
                max_tokens=1024,
                temperature=0.5,
            )

            grading = json.loads(self._strip_code_fences(raw))
            return {
                "score": float(grading.get("score", 0.0)),
                "feedback": grading.get("feedback", "")
            }

        except Exception as e:
            print(f"Error grading answer: {e}")
            # Deliberate best-effort fallback: grading failures must not
            # break answer submission, so return a default grade.
            return {
                "score": 0.0,
                "feedback": "Unable to grade answer due to an error."
            }
|
||||
|
||||
|
||||
# Module-level singleton: one shared LLM client for parsing and grading routes.
llm_service = LLMService()
|
||||
83
backend/utils.py
Normal file
83
backend/utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Utility functions
|
||||
"""
|
||||
import hashlib
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Optional

from jose import JWTError, jwt
from passlib.context import CryptContext
|
||||
|
||||
# Password hashing: bcrypt with passlib's automatic deprecation handling.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# JWT settings.
# NOTE(review): the fallback SECRET_KEY is a known placeholder — deployments
# MUST set the SECRET_KEY environment variable to a random value, otherwise
# tokens can be forged.
SECRET_KEY = os.getenv("SECRET_KEY", "your-secret-key-change-this")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7  # 7 days
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
    """Return a bcrypt hash of the given plaintext password."""
    return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the stored bcrypt hash."""
    return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed in the token; copied, the caller's dict is
            not mutated.
        expires_delta: Optional custom lifetime. Defaults to
            ACCESS_TOKEN_EXPIRE_MINUTES when None.

    Returns:
        The encoded JWT string.
    """
    to_encode = data.copy()
    # Use timezone-aware UTC: datetime.utcnow() is deprecated and yields
    # naive datetimes, which can mis-encode the "exp" claim.
    now = datetime.now(timezone.utc)
    if expires_delta is not None:
        # Explicit None check so a zero/negative delta (immediate expiry)
        # is honored instead of silently falling back to the default.
        expire = now + expires_delta
    else:
        expire = now + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT; return its payload, or None when invalid/expired."""
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
|
||||
|
||||
|
||||
def normalize_content(content: str) -> str:
    """
    Canonicalize text for deduplication.

    Strips whitespace, drops everything except word characters and CJK
    ideographs (U+4E00-U+9FFF), and lowercases the result, so superficially
    different copies of the same question compare equal.
    """
    stripped = re.sub(r'\s+', '', content)
    stripped = re.sub(r'[^\w\u4e00-\u9fff]', '', stripped)
    return stripped.lower()
|
||||
|
||||
|
||||
def calculate_content_hash(content: str) -> str:
    """
    Return the MD5 hex digest of the normalized content.

    Used purely as a deduplication key, not for security.
    """
    return hashlib.md5(normalize_content(content).encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
def get_file_extension(filename: str) -> str:
    """Return the lowercase extension of *filename*, or '' when it has none."""
    _, dot, extension = filename.rpartition('.')
    return extension.lower() if dot else ''
|
||||
|
||||
|
||||
def is_allowed_file(filename: str) -> bool:
    """Return True when the filename's extension is one the upload pipeline supports."""
    extension = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
    return extension in {'txt', 'pdf', 'doc', 'docx', 'xlsx', 'xls'}
|
||||
Reference in New Issue
Block a user