名称:python-测试-模式 描述:使用pytest、fixtures、mocking和测试驱动开发实现全面的测试策略。适用于编写Python测试、设置测试套件或实施测试最佳实践。
Python 测试模式
使用pytest、fixtures、mocking、参数化测试和测试驱动开发实践在Python中实现稳健测试策略的全面指南。
何时使用此技能
- 为Python代码编写单元测试
- 设置测试套件和测试基础设施
- 实施测试驱动开发(TDD)
- 为API和服务创建集成测试
- Mocking外部依赖和服务
- 测试异步代码和并发操作
- 在CI/CD中设置持续测试
- 实现基于属性的测试
- 测试数据库操作
- 调试失败的测试
核心概念
1. 测试类型
- 单元测试:隔离测试单个函数/类
- 集成测试:测试组件之间的交互
- 功能测试:端到端测试完整功能
- 性能测试:测量速度和资源使用
2. 测试结构(AAA模式)
- 安排:设置测试数据和前提条件
- 执行:执行被测代码
- 断言:验证结果
3. 测试覆盖率
- 测量代码被测试覆盖的部分
- 识别未测试的代码路径
- 追求有意义的覆盖率,而非仅高百分比
4. 测试隔离
- 测试应独立
- 测试间无共享状态
- 每个测试应自行清理
快速开始
# test_example.py
def add(a, b):
return a + b
def test_add():
"""基本测试示例。"""
result = add(2, 3)
assert result == 5
def test_add_negative():
"""测试负数。"""
assert add(-1, 1) == 0
# 运行:pytest test_example.py
基础模式
模式 1:基本pytest测试
# test_calculator.py
import pytest
class Calculator:
"""用于测试的简单计算器。"""
def add(self, a: float, b: float) -> float:
return a + b
def subtract(self, a: float, b: float) -> float:
return a - b
def multiply(self, a: float, b: float) -> float:
return a * b
def divide(self, a: float, b: float) -> float:
if b == 0:
raise ValueError("不能除以零")
return a / b
def test_addition():
"""测试加法。"""
calc = Calculator()
assert calc.add(2, 3) == 5
assert calc.add(-1, 1) == 0
assert calc.add(0, 0) == 0
def test_subtraction():
"""测试减法。"""
calc = Calculator()
assert calc.subtract(5, 3) == 2
assert calc.subtract(0, 5) == -5
def test_multiplication():
"""测试乘法。"""
calc = Calculator()
assert calc.multiply(3, 4) == 12
assert calc.multiply(0, 5) == 0
def test_division():
"""测试除法。"""
calc = Calculator()
assert calc.divide(6, 3) == 2
assert calc.divide(5, 2) == 2.5
def test_division_by_zero():
"""测试除以零引发错误。"""
calc = Calculator()
with pytest.raises(ValueError, match="不能除以零"):
calc.divide(5, 0)
模式 2:用于设置和清理的Fixtures
# test_database.py
import pytest
from typing import Generator
class Database:
"""简单数据库类。"""
def __init__(self, connection_string: str):
self.connection_string = connection_string
self.connected = False
def connect(self):
"""连接数据库。"""
self.connected = True
def disconnect(self):
"""断开数据库连接。"""
self.connected = False
def query(self, sql: str) -> list:
"""执行查询。"""
if not self.connected:
raise RuntimeError("未连接")
return [{"id": 1, "name": "测试"}]
@pytest.fixture
def db() -> Generator[Database, None, None]:
"""提供已连接数据库的Fixture。"""
# 设置
database = Database("sqlite:///:memory:")
database.connect()
# 提供给测试
yield database
# 清理
database.disconnect()
def test_database_query(db):
"""使用fixture测试数据库查询。"""
results = db.query("SELECT * FROM users")
assert len(results) == 1
assert results[0]["name"] == "测试"
@pytest.fixture(scope="session")
def app_config():
"""会话范围的fixture - 每个测试会话创建一次。"""
return {
"database_url": "postgresql://localhost/test",
"api_key": "test-key",
"debug": True
}
@pytest.fixture(scope="module")
def api_client(app_config):
"""模块范围的fixture - 每个测试模块创建一次。"""
# 设置昂贵资源
client = {"config": app_config, "session": "active"}
yield client
# 清理
client["session"] = "closed"
def test_api_client(api_client):
"""测试使用api客户端fixture。"""
assert api_client["session"] == "active"
assert api_client["config"]["debug"] is True
模式 3:参数化测试
# test_validation.py
import pytest
def is_valid_email(email: str) -> bool:
"""检查邮箱是否有效。"""
return "@" in email and "." in email.split("@")[1]
@pytest.mark.parametrize("email,expected", [
("user@example.com", True),
("test.user@domain.co.uk", True),
("invalid.email", False),
("@example.com", False),
("user@domain", False),
("", False),
])
def test_email_validation(email, expected):
"""使用各种输入测试邮箱验证。"""
assert is_valid_email(email) == expected
@pytest.mark.parametrize("a,b,expected", [
(2, 3, 5),
(0, 0, 0),
(-1, 1, 0),
(100, 200, 300),
(-5, -5, -10),
])
def test_addition_parameterized(a, b, expected):
"""使用多组参数测试加法。"""
from test_calculator import Calculator
calc = Calculator()
assert calc.add(a, b) == expected
# 使用pytest.param处理特殊情况
@pytest.mark.parametrize("value,expected", [
pytest.param(1, True, id="positive"),
pytest.param(0, False, id="zero"),
pytest.param(-1, False, id="negative"),
])
def test_is_positive(value, expected):
"""使用自定义测试ID测试。"""
assert (value > 0) == expected
模式 4:使用unittest.mock进行Mocking
# test_api_client.py
import pytest
from unittest.mock import Mock, patch, MagicMock
import requests
class APIClient:
"""简单API客户端。"""
def __init__(self, base_url: str):
self.base_url = base_url
def get_user(self, user_id: int) -> dict:
"""从API获取用户。"""
response = requests.get(f"{self.base_url}/users/{user_id}")
response.raise_for_status()
return response.json()
def create_user(self, data: dict) -> dict:
"""创建新用户。"""
response = requests.post(f"{self.base_url}/users", json=data)
response.raise_for_status()
return response.json()
def test_get_user_success():
"""使用mock测试成功的API调用。"""
client = APIClient("https://api.example.com")
mock_response = Mock()
mock_response.json.return_value = {"id": 1, "name": "John Doe"}
mock_response.raise_for_status.return_value = None
with patch("requests.get", return_value=mock_response) as mock_get:
user = client.get_user(1)
assert user["id"] == 1
assert user["name"] == "John Doe"
mock_get.assert_called_once_with("https://api.example.com/users/1")
def test_get_user_not_found():
"""测试404错误的API调用。"""
client = APIClient("https://api.example.com")
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
with patch("requests.get", return_value=mock_response):
with pytest.raises(requests.HTTPError):
client.get_user(999)
@patch("requests.post")
def test_create_user(mock_post):
"""使用装饰器语法测试用户创建。"""
client = APIClient("https://api.example.com")
mock_post.return_value.json.return_value = {"id": 2, "name": "Jane Doe"}
mock_post.return_value.raise_for_status.return_value = None
user_data = {"name": "Jane Doe", "email": "jane@example.com"}
result = client.create_user(user_data)
assert result["id"] == 2
mock_post.assert_called_once()
call_args = mock_post.call_args
assert call_args.kwargs["json"] == user_data
模式 5:测试异常
# test_exceptions.py
import pytest
def divide(a: float, b: float) -> float:
"""将a除以b。"""
if b == 0:
raise ZeroDivisionError("除以零")
if not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
raise TypeError("参数必须是数字")
return a / b
def test_zero_division():
"""测试除以零引发异常。"""
with pytest.raises(ZeroDivisionError):
divide(10, 0)
def test_zero_division_with_message():
"""测试异常消息。"""
with pytest.raises(ZeroDivisionError, match="除以零"):
divide(5, 0)
def test_type_error():
"""测试类型错误异常。"""
with pytest.raises(TypeError, match="必须是数字"):
divide("10", 5)
def test_exception_info():
"""测试访问异常信息。"""
with pytest.raises(ValueError) as exc_info:
int("不是数字")
assert "无效字面量" in str(exc_info.value)
高级模式
模式 6:测试异步代码
# test_async.py
import pytest
import asyncio
async def fetch_data(url: str) -> dict:
"""异步获取数据。"""
await asyncio.sleep(0.1)
return {"url": url, "data": "result"}
@pytest.mark.asyncio
async def test_fetch_data():
"""测试异步函数。"""
result = await fetch_data("https://api.example.com")
assert result["url"] == "https://api.example.com"
assert "data" in result
@pytest.mark.asyncio
async def test_concurrent_fetches():
"""测试并发异步操作。"""
urls = ["url1", "url2", "url3"]
tasks = [fetch_data(url) for url in urls]
results = await asyncio.gather(*tasks)
assert len(results) == 3
assert all("data" in r for r in results)
@pytest.fixture
async def async_client():
"""异步fixture。"""
client = {"connected": True}
yield client
client["connected"] = False
@pytest.mark.asyncio
async def test_with_async_fixture(async_client):
"""测试使用异步fixture。"""
assert async_client["connected"] is True
模式 7:用于测试的Monkeypatch
# test_environment.py
import os
import pytest
def get_database_url() -> str:
"""从环境获取数据库URL。"""
return os.environ.get("DATABASE_URL", "sqlite:///:memory:")
def test_database_url_default():
"""测试默认数据库URL。"""
# 如果设置了实际环境变量,将使用它
url = get_database_url()
assert url
def test_database_url_custom(monkeypatch):
"""使用monkeypatch测试自定义数据库URL。"""
monkeypatch.setenv("DATABASE_URL", "postgresql://localhost/test")
assert get_database_url() == "postgresql://localhost/test"
def test_database_url_not_set(monkeypatch):
"""测试环境变量未设置时。"""
monkeypatch.delenv("DATABASE_URL", raising=False)
assert get_database_url() == "sqlite:///:memory:"
class Config:
"""配置类。"""
def __init__(self):
self.api_key = "production-key"
def get_api_key(self):
return self.api_key
def test_monkeypatch_attribute(monkeypatch):
"""测试monkeypatching对象属性。"""
config = Config()
monkeypatch.setattr(config, "api_key", "test-key")
assert config.get_api_key() == "test-key"
模式 8:临时文件和目录
# test_file_operations.py
import pytest
from pathlib import Path
def save_data(filepath: Path, data: str):
"""保存数据到文件。"""
filepath.write_text(data)
def load_data(filepath: Path) -> str:
"""从文件加载数据。"""
return filepath.read_text()
def test_file_operations(tmp_path):
"""使用临时目录测试文件操作。"""
# tmp_path 是一个 pathlib.Path 对象
test_file = tmp_path / "test_data.txt"
# 保存数据
save_data(test_file, "Hello, World!")
# 验证文件存在
assert test_file.exists()
# 加载并验证数据
data = load_data(test_file)
assert data == "Hello, World!"
def test_multiple_files(tmp_path):
"""测试多个临时文件。"""
files = {
"file1.txt": "Content 1",
"file2.txt": "Content 2",
"file3.txt": "Content 3"
}
for filename, content in files.items():
filepath = tmp_path / filename
save_data(filepath, content)
# 验证所有文件已创建
assert len(list(tmp_path.iterdir())) == 3
# 验证内容
for filename, expected_content in files.items():
filepath = tmp_path / filename
assert load_data(filepath) == expected_content
模式 9:自定义Fixtures和Conftest
# conftest.py
"""所有测试的共享fixtures。"""
import pytest
@pytest.fixture(scope="session")
def database_url():
"""为所有测试提供数据库URL。"""
return "postgresql://localhost/test_db"
@pytest.fixture(autouse=True)
def reset_database(database_url):
"""在每个测试前运行的自动使用fixture。"""
# 设置:清空数据库
print(f"清空数据库:{database_url}")
yield
# 清理:清理
print("测试完成")
@pytest.fixture
def sample_user():
"""提供样本用户数据。"""
return {
"id": 1,
"name": "测试用户",
"email": "test@example.com"
}
@pytest.fixture
def sample_users():
"""提供样本用户列表。"""
return [
{"id": 1, "name": "用户 1"},
{"id": 2, "name": "用户 2"},
{"id": 3, "name": "用户 3"},
]
# 参数化fixture
@pytest.fixture(params=["sqlite", "postgresql", "mysql"])
def db_backend(request):
"""用不同数据库后端运行测试的fixture。"""
return request.param
def test_with_db_backend(db_backend):
"""这个测试将用不同后端运行3次。"""
print(f"用 {db_backend} 测试")
assert db_backend in ["sqlite", "postgresql", "mysql"]
模式 10:基于属性的测试
# test_properties.py
from hypothesis import given, strategies as st
import pytest
def reverse_string(s: str) -> str:
"""反转字符串。"""
return s[::-1]
@given(st.text())
def test_reverse_twice_is_original(s):
"""属性:反转两次返回原字符串。"""
assert reverse_string(reverse_string(s)) == s
@given(st.text())
def test_reverse_length(s):
"""属性:反转字符串长度相同。"""
assert len(reverse_string(s)) == len(s)
@given(st.integers(), st.integers())
def test_addition_commutative(a, b):
"""属性:加法交换律。"""
assert a + b == b + a
@given(st.lists(st.integers()))
def test_sorted_list_properties(lst):
"""属性:排序列表有序。"""
sorted_lst = sorted(lst)
# 相同长度
assert len(sorted_lst) == len(lst)
# 所有元素存在
assert set(sorted_lst) == set(lst)
# 有序
for i in range(len(sorted_lst) - 1):
assert sorted_lst[i] <= sorted_lst[i + 1]
测试最佳实践
测试组织
# tests/
# __init__.py
# conftest.py # 共享fixtures
# test_unit/ # 单元测试
# test_models.py
# test_utils.py
# test_integration/ # 集成测试
# test_api.py
# test_database.py
# test_e2e/ # 端到端测试
# test_workflows.py
测试命名
# 好的测试名称
def test_user_creation_with_valid_data():
"""清晰名称描述被测试的内容。"""
pass
def test_login_fails_with_invalid_password():
"""名称描述预期行为。"""
pass
def test_api_returns_404_for_missing_resource():
"""具体说明输入和预期结果。"""
pass
# 差的测试名称
def test_1(): # 不具描述性
pass
def test_user(): # 太模糊
pass
def test_function(): # 未解释测试内容
pass
测试标记
# test_markers.py
import pytest
@pytest.mark.slow
def test_slow_operation():
"""标记慢测试。"""
import time
time.sleep(2)
@pytest.mark.integration
def test_database_integration():
"""标记集成测试。"""
pass
@pytest.mark.skip(reason="功能尚未实现")
def test_future_feature():
"""暂时跳过测试。"""
pass
@pytest.mark.skipif(os.name == "nt", reason="仅Unix测试")
def test_unix_specific():
"""条件跳过。"""
pass
@pytest.mark.xfail(reason="已知错误 #123")
def test_known_bug():
"""标记预期失败。"""
assert False
# 运行:
# pytest -m slow # 仅运行慢测试
# pytest -m "not slow" # 跳过慢测试
# pytest -m integration # 运行集成测试
覆盖率报告
# 安装覆盖率
pip install pytest-cov
# 运行测试带覆盖率
pytest --cov=myapp tests/
# 生成HTML报告
pytest --cov=myapp --cov-report=html tests/
# 如果覆盖率低于阈值则失败
pytest --cov=myapp --cov-fail-under=80 tests/
# 显示缺失行
pytest --cov=myapp --cov-report=term-missing tests/
测试数据库代码
# test_database_models.py
import pytest
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
Base = declarative_base()
class User(Base):
"""用户模型。"""
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(50))
email = Column(String(100), unique=True)
@pytest.fixture(scope="function")
def db_session() -> Session:
"""为测试创建内存数据库。"""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)
session = SessionLocal()
yield session
session.close()
def test_create_user(db_session):
"""测试创建用户。"""
user = User(name="测试用户", email="test@example.com")
db_session.add(user)
db_session.commit()
assert user.id is not None
assert user.name == "测试用户"
def test_query_user(db_session):
"""测试查询用户。"""
user1 = User(name="用户 1", email="user1@example.com")
user2 = User(name="用户 2", email="user2@example.com")
db_session.add_all([user1, user2])
db_session.commit()
users = db_session.query(User).all()
assert len(users) == 2
def test_unique_email_constraint(db_session):
"""测试唯一邮箱约束。"""
from sqlalchemy.exc import IntegrityError
user1 = User(name="用户 1", email="same@example.com")
user2 = User(name="用户 2", email="same@example.com")
db_session.add(user1)
db_session.commit()
db_session.add(user2)
with pytest.raises(IntegrityError):
db_session.commit()
CI/CD 集成
# .github/workflows/test.yml
名称:测试
触发事件:[push, pull_request]
任务:
测试:
运行环境:ubuntu-latest
策略:
矩阵:
python-版本:["3.9", "3.10", "3.11", "3.12"]
步骤:
- 使用:actions/checkout@v3
- 名称:设置Python
使用:actions/setup-python@v4
参数:
python-版本:${{ matrix.python-版本 }}
- 名称:安装依赖
运行: |
pip install -e ".[dev]"
pip install pytest pytest-cov
- 名称:运行测试
运行: |
pytest --cov=myapp --cov-report=xml
- 名称:上传覆盖率
使用:codecov/codecov-action@v3
参数:
文件:./coverage.xml
配置文件
# pytest.ini
[pytest]
测试路径 = tests
python文件 = test_*.py
python类 = Test*
python函数 = test_*
附加选项 =
-v
--strict-markers
--tb=short
--cov=myapp
--cov-report=term-missing
标记 =
slow: 标记慢测试
integration: 标记集成测试
unit: 标记单元测试
e2e: 标记端到端测试
# pyproject.toml
[tool.pytest.ini_options]
测试路径 = ["tests"]
python文件 = ["test_*.py"]
附加选项 = [
"-v",
"--cov=myapp",
"--cov-report=term-missing",
]
[tool.coverage.run]
源 = ["myapp"]
省略 = ["*/tests/*", "*/migrations/*"]
[tool.coverage.report]
排除行 = [
"pragma: no cover",
"def __repr__",
"raise AssertionError",
"raise NotImplementedError",
]
资源
- pytest文档:https://docs.pytest.org/
- unittest.mock:https://docs.python.org/3/library/unittest.mock.html
- hypothesis:基于属性的测试
- pytest-asyncio:测试异步代码
- pytest-cov:覆盖率报告
- pytest-mock:pytest的mock包装器
最佳实践总结
- 先写测试(TDD)或与代码一起写
- 每个测试一个断言(可能时)
- 使用描述性测试名称解释行为
- 保持测试独立和隔离
- 使用fixtures进行设置和清理
- 适当mock外部依赖
- 参数化测试减少重复
- 测试边缘情况和错误条件
- 测量覆盖率但关注质量
- 在CI/CD中运行测试每次提交