name: fastapi description: 使用FastAPI进行REST API和WebSocket开发,强调安全性、性能和异步模式 model: sonnet risk_level: HIGH
FastAPI开发技能
文件组织
- SKILL.md: 核心原则、模式、基本安全性(本文件)
- references/security-examples.md: CVE详情和OWASP实现
- references/advanced-patterns.md: 高级FastAPI模式
- references/threat-model.md: 攻击场景和STRIDE分析
验证门
门0.2:漏洞研究(高风险阻塞)
- 状态: 通过(记录5+个CVE)
- 研究日期: 2025-11-20
- CVE: CVE-2024-47874, CVE-2024-12868, CVE-2023-30798, Starlette DoS变种
1. 概述
风险等级: HIGH
理由: FastAPI应用处理认证、数据库访问、文件上传和外部API通信。Starlette中的DoS漏洞、注入风险和不当验证可能危及可用性和安全性。
您是一位专家FastAPI开发者,创建安全、高性能的REST API和WebSocket服务。您配置适当的验证、认证和安全头部。
核心专业领域
- Pydantic验证和依赖注入
- 认证:OAuth2、JWT、API密钥
- 安全头部和CORS配置
- 速率限制和DoS保护
- 使用异步ORM进行数据库集成
- WebSocket安全
2. 核心职责
基本原则
- 测试驱动开发优先: 在实现代码前编写测试
- 性能意识: 连接池、缓存、异步模式
- 验证一切: 使用Pydantic模型处理所有输入
- 默认安全: HTTPS、安全头部、严格CORS
- 速率限制: 保护所有端点免受滥用
- 认证与授权: 验证身份和权限
- 安全处理错误: 绝不泄露内部细节
3. 技术基础
版本推荐
| 组件 | 版本 | 备注 |
|---|---|---|
| FastAPI | 0.115.3+ | CVE-2024-47874修复 |
| Starlette | 0.40.0+ | DoS漏洞修复 |
| Pydantic | 2.0+ | 更好的验证 |
| Python | 3.11+ | 性能 |
安全依赖
[project]
dependencies = [
"fastapi>=0.115.3",
"starlette>=0.40.0",
"pydantic>=2.5",
"python-jose[cryptography]>=3.3",
"passlib[argon2]>=1.7",
"python-multipart>=0.0.6",
"slowapi>=0.1.9",
"secure>=0.3",
]
4. 实现模式
模式1:安全应用设置
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from secure import SecureHeaders
app = FastAPI(
title="安全API",
docs_url=None if PRODUCTION else "/docs", # 生产中禁用
redoc_url=None,
)
# 安全头部
secure_headers = SecureHeaders()
@app.middleware("http")
async def add_security_headers(request, call_next):
response = await call_next(request)
secure_headers.framework.fastapi(response)
return response
# 限制性CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["https://app.example.com"], # 永不使用["*"]
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["Authorization", "Content-Type"],
)
模式2:输入验证
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, field_validator, EmailStr
class UserCreate(BaseModel):
username: str = Field(min_length=3, max_length=50, pattern=r'^[a-zA-Z0-9_-]+$')
email: EmailStr
password: str = Field(min_length=12)
@field_validator('password')
@classmethod
def validate_password(cls, v):
if not any(c.isupper() for c in v):
raise ValueError('必须包含大写字母')
if not any(c.isdigit() for c in v):
raise ValueError('必须包含数字')
return v
@app.post("/users")
async def create_user(user: UserCreate):
# 输入已由Pydantic验证
return await user_service.create(user)
模式3:JWT认证
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from datetime import datetime, timedelta
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
SECRET_KEY = os.environ["JWT_SECRET"]
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
def create_access_token(data: dict) -> str:
to_encode = data.copy()
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无法验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
if user_id is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = await user_service.get(user_id)
if user is None:
raise credentials_exception
return user
模式4:速率限制
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
@app.post("/login")
@limiter.limit("5/分钟") # 对认证端点严格
async def login(request: Request, credentials: LoginRequest):
return await auth_service.login(credentials)
@app.get("/data")
@limiter.limit("100/分钟")
async def get_data(request: Request):
return await data_service.get_all()
模式5:安全文件上传
from fastapi import UploadFile, File, HTTPException
import magic
ALLOWED_TYPES = {"image/jpeg", "image/png", "application/pdf"}
MAX_SIZE = 10 * 1024 * 1024 # 10MB
@app.post("/upload")
async def upload_file(file: UploadFile = File(...)):
# 检查大小
content = await file.read()
if len(content) > MAX_SIZE:
raise HTTPException(400, "文件过大")
# 检查魔术字节,不仅仅是扩展名
mime_type = magic.from_buffer(content, mime=True)
if mime_type not in ALLOWED_TYPES:
raise HTTPException(400, f"文件类型不允许: {mime_type}")
# 生成安全文件名
safe_name = f"{uuid4()}{Path(file.filename).suffix}"
# 存储在Web根目录外
file_path = UPLOAD_DIR / safe_name
file_path.write_bytes(content)
return {"filename": safe_name}
5. 实现工作流程(TDD)
步骤1:先写失败测试
始终从定义预期行为的测试开始:
import pytest
from httpx import AsyncClient, ASGITransport
from app.main import app
@pytest.mark.asyncio
async def test_create_item_success():
"""测试有效数据的成功物品创建。"""
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/items",
json={"name": "测试物品", "price": 29.99},
headers={"Authorization": "Bearer valid_token"}
)
assert response.status_code == 201
data = response.json()
assert data["name"] == "测试物品"
assert "id" in data
@pytest.mark.asyncio
async def test_create_item_validation_error():
"""测试验证拒绝无效价格。"""
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post(
"/items",
json={"name": "测试", "price": -10},
headers={"Authorization": "Bearer valid_token"}
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_create_item_unauthorized():
"""测试端点需要认证。"""
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
response = await client.post("/items", json={"name": "测试", "price": 10})
assert response.status_code == 401
步骤2:实现最小通过代码
仅编写使测试通过的代码:
@app.post("/items", status_code=201)
async def create_item(
item: ItemCreate,
user: User = Depends(get_current_user)
) -> ItemResponse:
created = await item_service.create(item, user.id)
return ItemResponse.from_orm(created)
步骤3:如需重构
在保持测试通过的同时改进代码质量。提取通用模式、改进命名、优化查询。
步骤4:运行完整验证
# 运行所有测试,带覆盖率
pytest --cov=app --cov-report=term-missing
# 类型检查
mypy app --strict
# 安全扫描
bandit -r app -ll
# 提交前必须全部通过
6. 性能模式
模式1:数据库连接池
# 不好 - 每个请求创建新连接
@app.get("/users/{user_id}")
async def get_user(user_id: int):
conn = await asyncpg.connect(DATABASE_URL)
try:
return await conn.fetchrow("SELECT * FROM users WHERE id = $1", user_id)
finally:
await conn.close()
# 好 - 使用连接池
from contextlib import asynccontextmanager
pool: asyncpg.Pool = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global pool
pool = await asyncpg.create_pool(
DATABASE_URL,
min_size=5,
max_size=20,
command_timeout=60
)
yield
await pool.close()
app = FastAPI(lifespan=lifespan)
@app.get("/users/{user_id}")
async def get_user(user_id: int):
async with pool.acquire() as conn:
return await conn.fetchrow("SELECT * FROM users WHERE id = $1", user_id)
模式2:使用asyncio.gather的并发请求
# 不好 - 顺序外部API调用
@app.get("/dashboard")
async def get_dashboard(user_id: int):
profile = await fetch_profile(user_id) # 100ms
orders = await fetch_orders(user_id) # 150ms
notifications = await fetch_notifications(user_id) # 80ms
return {"profile": profile, "orders": orders, "notifications": notifications}
# 总计: ~330ms
# 好 - 并发调用
@app.get("/dashboard")
async def get_dashboard(user_id: int):
profile, orders, notifications = await asyncio.gather(
fetch_profile(user_id),
fetch_orders(user_id),
fetch_notifications(user_id)
)
return {"profile": profile, "orders": orders, "notifications": notifications}
# 总计: ~150ms(最慢调用)
模式3:响应缓存
# 不好 - 每次请求重新计算昂贵数据
@app.get("/stats")
async def get_stats():
return await compute_expensive_stats() # 每次500ms
# 好 - 使用Redis缓存
from fastapi_cache import FastAPICache
from fastapi_cache.backends.redis import RedisBackend
from fastapi_cache.decorator import cache
@asynccontextmanager
async def lifespan(app: FastAPI):
redis = aioredis.from_url("redis://localhost")
FastAPICache.init(RedisBackend(redis), prefix="api-cache")
yield
@app.get("/stats")
@cache(expire=300) # 缓存5分钟
async def get_stats():
return await compute_expensive_stats()
# 好 - 简单情况下的内存缓存
from functools import lru_cache
from datetime import datetime, timedelta
_cache = {}
_cache_time = {}
async def get_cached_config(key: str, ttl: int = 60):
now = datetime.utcnow()
if key in _cache and _cache_time[key] > now:
return _cache[key]
value = await fetch_config(key)
_cache[key] = value
_cache_time[key] = now + timedelta(seconds=ttl)
return value
模式4:大型数据集的分页
# 不好 - 返回所有记录
@app.get("/items")
async def list_items():
return await db.fetch("SELECT * FROM items") # 可能数百万
# 好 - 基于游标的分页
from pydantic import BaseModel
class PaginatedResponse(BaseModel):
items: list
next_cursor: str | None
has_more: bool
@app.get("/items")
async def list_items(
cursor: str | None = None,
limit: int = Query(default=20, le=100)
) -> PaginatedResponse:
query = "SELECT * FROM items"
params = []
if cursor:
query += " WHERE id > $1"
params.append(decode_cursor(cursor))
query += f" ORDER BY id LIMIT {limit + 1}"
rows = await db.fetch(query, *params)
has_more = len(rows) > limit
items = rows[:limit]
return PaginatedResponse(
items=items,
next_cursor=encode_cursor(items[-1]["id"]) if items else None,
has_more=has_more
)
模式5:重型操作的背景任务
# 不好 - 阻塞慢操作的响应
@app.post("/reports")
async def create_report(request: ReportRequest):
report = await generate_report(request) # 花费30秒
await send_email(request.email, report)
return {"status": "完成"}
# 好 - 立即返回,在背景处理
from fastapi import BackgroundTasks
@app.post("/reports", status_code=202)
async def create_report(
request: ReportRequest,
background_tasks: BackgroundTasks
):
report_id = str(uuid4())
background_tasks.add_task(process_report, report_id, request)
return {"report_id": report_id, "status": "处理中"}
async def process_report(report_id: str, request: ReportRequest):
report = await generate_report(request)
await save_report(report_id, report)
await send_email(request.email, report)
@app.get("/reports/{report_id}")
async def get_report_status(report_id: str):
return await get_report(report_id)
7. 安全标准
7.1 域漏洞景观
| CVE ID | 严重性 | 描述 | 缓解措施 |
|---|---|---|---|
| CVE-2024-47874 | HIGH | Starlette多部分DoS通过内存耗尽 | 升级Starlette 0.40.0+ |
| CVE-2024-12868 | HIGH | 通过fastapi依赖的下游DoS | 升级FastAPI 0.115.3+ |
| CVE-2023-30798 | HIGH | Starlette <0.25 DoS | 升级FastAPI 0.92+ |
7.2 OWASP Top 10 映射
| 类别 | 风险 | 缓解措施 |
|---|---|---|
| A01 访问控制 | HIGH | 依赖注入用于认证,权限装饰器 |
| A02 加密失败 | HIGH | JWT使用适当算法,Argon2密码 |
| A03 注入 | HIGH | Pydantic验证,参数化查询 |
| A04 不安全设计 | MEDIUM | 类型安全,验证层 |
| A05 错误配置 | HIGH | 安全头部,生产中禁用文档 |
| A06 易受攻击组件 | CRITICAL | 保持Starlette/FastAPI更新 |
| A07 认证失败 | HIGH | 认证速率限制,安全JWT |
7.3 错误处理
from fastapi import HTTPException
from fastapi.responses import JSONResponse
import logging
logger = logging.getLogger(__name__)
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
# 记录完整细节
logger.error(f"未处理错误: {exc}", exc_info=True)
# 返回安全消息
return JSONResponse(
status_code=500,
content={"detail": "内部服务器错误"}
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
return JSONResponse(
status_code=exc.status_code,
content={"detail": exc.detail}
)
6. 测试与验证
安全测试
import pytest
from fastapi.testclient import TestClient
def test_rate_limiting():
client = TestClient(app)
# 超出速率限制
for _ in range(10):
response = client.post("/login", json={"username": "test", "password": "test"})
assert response.status_code == 429
def test_invalid_jwt_rejected():
client = TestClient(app)
response = client.get(
"/protected",
headers={"Authorization": "Bearer invalid.token.here"}
)
assert response.status_code == 401
def test_sql_injection_prevented():
client = TestClient(app)
response = client.get("/users", params={"search": "'; DROP TABLE users; --"})
assert response.status_code in [200, 400]
# 不应导致500(SQL错误)
def test_file_upload_type_validation():
client = TestClient(app)
# 尝试上传伪装为图像的可执行文件
response = client.post(
"/upload",
files={"file": ("test.jpg", b"MZ\x90\x00", "image/jpeg")} # EXE魔术字节
)
assert response.status_code == 400
8. 常见错误与反模式
反模式1:宽松CORS
# 永不
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True)
# 始终
app.add_middleware(CORSMiddleware, allow_origins=["https://app.example.com"])
反模式2:无速率限制
# 永不 - 允许暴力破解
@app.post("/login")
async def login(creds): ...
# 始终
@app.post("/login")
@limiter.limit("5/分钟")
async def login(request, creds): ...
反模式3:生产中暴露文档
# 永不
app = FastAPI()
# 始终
app = FastAPI(
docs_url=None if os.environ.get("ENV") == "production" else "/docs",
redoc_url=None
)
反模式4:弱JWT配置
# 永不
jwt.encode(data, "secret", algorithm="HS256") # 硬编码弱密钥
# 始终
jwt.encode(data, os.environ["JWT_SECRET"], algorithm="RS256") # 环境变量,强算法
反模式5:仅文件扩展名验证
# 永不
if file.filename.endswith('.jpg'): ...
# 始终
mime = magic.from_buffer(content, mime=True)
if mime not in ALLOWED_TYPES: ...
13. 部署前清单
- [ ] FastAPI 0.115.3+ / Starlette 0.40.0+
- [ ] 配置安全头部中间件
- [ ] 限制性CORS(不使用带凭证的通配符)
- [ ] 所有端点速率限制
- [ ] 认证端点更严格限制
- [ ] JWT使用来自环境的强密钥
- [ ] 所有输入Pydantic验证
- [ ] 文件上传检查魔术字节
- [ ] 生产中禁用文档
- [ ] 错误处理不泄露内部细节
- [ ] 强制HTTPS
14. 总结
您的目标是创建FastAPI应用,这些应用:
- 安全: 验证输入、速率限制、安全头部
- 高性能: 异步操作、适当连接池
- 可维护: 类型安全、结构良好、经过测试
安全提醒:
- 升级Starlette到0.40.0+(CVE-2024-47874)
- 所有端点速率限制,特别是认证
- 通过魔术字节验证文件上传,而非扩展名
- 永不使用带凭证的通配符CORS
- 生产中禁用API文档