name: pytest
description: Python测试框架,用于编写简单、可扩展且功能强大的测试
when_to_use: 当您需要使用夹具、参数化或异步测试来编写、运行或组织Python测试时
Pytest 测试框架
Pytest 是一个成熟的 Python 测试框架,它使得编写小型测试变得容易,同时能够扩展以支持复杂的功能测试。
快速开始
基本测试结构
# test_example.py
def test_addition():
    """Check that integer addition gives the expected sum."""
    total = 2 + 2
    assert total == 4
def test_string_operations():
    """Check upper-casing and substring membership on plain strings."""
    shouted = "hello".upper()
    assert shouted == "HELLO"
    sentence = "hello world"
    assert "world" in sentence
运行测试
# Run all tests
uv run pytest
# Run with verbose output
uv run pytest -v
# Run a specific test file
uv run pytest test_example.py
# Run a specific test function
uv run pytest test_example.py::test_addition
常用模式
夹具
基本夹具定义:
import pytest
@pytest.fixture
def sample_data():
    """Provide a small user record for tests to read."""
    record = {"name": "Alice", "age": 30}
    return record
def test_user_data(sample_data):
assert sample_data["name"] == "Alice"
assert sample_data["age"] == 30
带设置和清理的夹具:
@pytest.fixture
def database_connection():
    """Yield an open database connection and close it after the test."""
    # Setup: everything before the yield runs before the test body.
    connection = create_database_connection()
    yield connection
    # Teardown: everything after the yield runs once the test is done.
    connection.close()
def test_database_query(database_connection):
result = database_connection.query("SELECT * FROM users")
assert len(result) > 0
夹具作用域:
@pytest.fixture(scope="function")  # default: created once per test
def temp_file():
    """Placeholder for a function-scoped fixture."""
    return None
@pytest.fixture(scope="module")  # created once per module
def module_resource():
    """Placeholder for a module-scoped fixture."""
    return None
@pytest.fixture(scope="session")  # created once per test session
def session_resource():
    """Placeholder for a session-scoped fixture."""
    return None
参数化
基本参数化:
@pytest.mark.parametrize(
    "input,expected",
    [("3+5", 8), ("2+4", 6), ("6*9", 54)],
)
def test_eval(input, expected):
    """Evaluate each arithmetic expression string and compare it with its expected value."""
    # eval() is only reasonable here because the inputs are the fixed literals above.
    actual = eval(input)
    assert actual == expected
参数化夹具:
@pytest.fixture(params=["mysql", "postgresql", "sqlite"])
def database(request):
    """Build a connection for the backend named by the current fixture parameter."""
    backend = request.param
    if backend == "mysql":
        return MySQLConnection()
    if backend == "postgresql":
        return PostgreSQLConnection()
    # Any other parameter value falls through to SQLite.
    return SQLiteConnection()
def test_database_operations(database):
# 测试运行3次,每种数据库类型一次
result = database.execute("SELECT 1")
assert result == 1
堆叠参数化进行组合测试:
@pytest.mark.parametrize("x", [0, 1])
@pytest.mark.parametrize("y", [2, 3])
def test_combinations(x, y):
    """Exercise every (x, y) pairing produced by the stacked parametrize marks."""
    # Runs 4 times, in this order: (0,2), (1,2), (0,3), (1,3).
    # The mark closest to the function (y) acts as the outer loop, so x varies fastest.
    assert x + y > 1
异步测试
基本异步测试:
import pytest
@pytest.mark.asyncio
async def test_async_function():
    """Await the coroutine under test and require a non-None result."""
    outcome = await async_operation()
    assert outcome is not None
异步夹具:
# NOTE(review): with pytest-asyncio in its default "strict" mode, async fixtures
# are expected to be declared with @pytest_asyncio.fixture; a plain
# @pytest.fixture is only handled when asyncio_mode = auto. Confirm which mode
# the project configures before copying this pattern.
@pytest.fixture
async def async_client():
    """Yield a connected client and disconnect it after the test."""
    http_client = AsyncClient()
    await http_client.connect()
    yield http_client
    await http_client.disconnect()
@pytest.mark.asyncio
async def test_async_api(async_client):
    """Request the data endpoint through the async client and expect a 200."""
    reply = await async_client.get("/api/data")
    status = reply.status_code
    assert status == 200
测试组织
使用 conftest.py 共享夹具:
# conftest.py
@pytest.fixture
def authenticated_client():
    """Provide a test client on which ``login`` has already been called."""
    session_client = create_test_client()
    session_client.login("testuser", "password")
    return session_client
@pytest.fixture(scope="session")
def test_database():
    """Yield a session-scoped test database and clean it up afterwards."""
    database = create_test_database()
    yield database
    # Teardown runs once, after the last test of the session that used it.
    database.cleanup()
测试类:
class TestUserAPI:
def test_create_user(self, authenticated_client):
response = authenticated_client.post("/users", json={"name": "John"})
assert response.status_code == 201
def test_get_user(self, authenticated_client):
user_id = create_test_user()
response = authenticated_client.get(f"/users/{user_id}")
assert response.status_code == 200
模拟和打补丁
使用 monkeypatch 夹具:
def test_environment_variable(monkeypatch):
    """Set API_KEY through monkeypatch and expect ``get_api_key`` to return it."""
    fake_key = "test-key"
    monkeypatch.setenv("API_KEY", fake_key)
    assert get_api_key() == fake_key
def test_file_operations(monkeypatch, tmp_path):
    """Point ``module.FILE_PATH`` at a temporary file and read its content back."""
    content = "test content"
    target = tmp_path / "test.txt"
    target.write_text(content)
    # Redirect the module-level path so the reader opens the temporary file.
    monkeypatch.setattr("module.FILE_PATH", str(target))
    assert read_file_content() == content
标记器和选择
自定义标记器:
# pytest.ini
# pytest.ini uses a [pytest] section; [tool:pytest] is the setup.cfg spelling.
[pytest]
markers =
    slow: 标记测试为慢速测试
    integration: 标记测试为集成测试
    unit: 标记测试为单元测试
# test_file.py
@pytest.mark.slow
def test_expensive_operation():
    """Placeholder for a test carrying the ``slow`` marker."""
    return None
@pytest.mark.integration
def test_database_integration():
    """Placeholder for a test carrying the ``integration`` marker."""
    return None
按标记器运行测试:
# Run only the unit tests
uv run pytest -m unit
# Skip the slow tests
uv run pytest -m "not slow"
# Run integration or unit tests
uv run pytest -m "integration or unit"
实用代码片段
API 测试
import pytest
from fastapi.testclient import TestClient
from myapp import app
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the imported ``app``."""
    test_client = TestClient(app)
    return test_client
@pytest.fixture
def test_user():
    """Provide the user payload shared by the API tests."""
    payload = {"username": "testuser", "email": "test@example.com"}
    return payload
def test_create_user(client, test_user):
response = client.post("/users/", json=test_user)
assert response.status_code == 201
assert response.json()["username"] == test_user["username"]
def test_get_user(client, test_user):
# 首先创建用户
create_response = client.post("/users/", json=test_user)
user_id = create_response.json()["id"]
# 获取用户
response = client.get(f"/users/{user_id}")
assert response.status_code == 200
assert response.json()["email"] == test_user["email"]
数据库测试
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
@pytest.fixture(scope="function")
def test_db():
    """Yield a session bound to a new in-memory SQLite engine.

    The schema is created from ``Base.metadata`` before the session is handed
    to the test. Afterwards the session is closed and the engine is disposed,
    so the per-test engine does not keep pooled connections open.
    """
    engine = create_engine("sqlite:///:memory:")
    Session = sessionmaker(bind=engine)
    Base.metadata.create_all(engine)
    session = Session()
    try:
        yield session
    finally:
        session.close()
        # Each test builds its own engine, so release its connection pool too.
        engine.dispose()
def test_user_creation(test_db):
    """Persist a user and read it back by name."""
    john = User(name="John", email="john@example.com")
    test_db.add(john)
    test_db.commit()
    lookup = test_db.query(User).filter_by(name="John")
    fetched = lookup.first()
    assert fetched.email == "john@example.com"
错误处理测试
def test_invalid_input_raises_error():
    """Expect ``process_input`` to raise a ValueError whose message matches."""
    expectation = pytest.raises(ValueError, match="Invalid input")
    with expectation:
        process_input("invalid")
def test_file_not_found():
    """Expect ``read_nonexistent_file`` to raise FileNotFoundError."""
    expectation = pytest.raises(FileNotFoundError)
    with expectation:
        read_nonexistent_file()
def test_custom_exception():
    """Inspect the status code and message carried by the raised CustomAPIError."""
    with pytest.raises(CustomAPIError) as exc_info:
        call_api_endpoint()
    # The captured exception is inspected after the ``with`` block has exited.
    error = exc_info.value
    assert error.status_code == 404
    assert "not found" in str(error)
要求
- Python 3.7+
- pytest (uv add --dev pytest)
- 对于异步测试: pytest-asyncio (uv add --dev pytest-asyncio)
- 对于API测试: Web框架测试客户端 (例如, uv add --dev httpx 用于异步HTTP测试)