name: python-type-safety description: 使用类型提示、泛型、协议和严格类型检查实现Python类型安全。适用于添加类型注释、实现泛型类、定义结构接口或配置mypy/pyright。
Python 类型安全
利用Python的类型系统在静态分析时捕获错误。类型注释作为强制文档,工具自动验证。
何时使用此技能
- 向现有代码添加类型提示
- 创建通用、可重用的类
- 使用协议定义结构接口
- 配置mypy或pyright进行严格检查
- 理解类型缩小和守卫
- 构建类型安全的API和库
核心概念
1. 类型注释
为函数参数、返回值和变量声明期望的类型。
2. 泛型
编写可重用代码,跨不同类型保留类型信息。
3. 协议
无需继承定义结构接口(具有类型安全的鸭子类型)。
4. 类型缩小
使用守卫和条件语句在代码块内缩小类型。
快速启动
def get_user(user_id: str) -> User | None:
"""返回类型使'可能不存在'明确。"""
...
# 类型检查器强制执行处理None情况
user = get_user("123")
if user is None:
raise UserNotFoundError("123")
print(user.name) # 类型检查器知道此处user是User类型
基础模式
模式1:注释所有公共签名
每个公共函数、方法和类都应具有类型注释。
def get_user(user_id: str) -> User:
"""通过ID检索用户。"""
...
def process_batch(
items: list[Item],
max_workers: int = 4,
) -> BatchResult[ProcessedItem]:
"""并发处理项目。"""
...
class UserRepository:
def __init__(self, db: Database) -> None:
self._db = db
async def find_by_id(self, user_id: str) -> User | None:
"""如果找到则返回User,否则返回None。"""
...
async def find_by_email(self, email: str) -> User | None:
...
async def save(self, user: User) -> User:
"""保存并返回带有生成ID的用户。"""
...
在CI中使用mypy --strict或pyright及早捕获类型错误。对于现有项目,使用每模块覆盖逐步启用严格模式。
模式2:使用现代联合语法
Python 3.10+提供了更简洁的联合语法。
# 首选(3.10+)
def find_user(user_id: str) -> User | None:
...
def parse_value(v: str) -> int | float | str:
...
# 旧风格(仍然有效,适用于3.9)
from typing import Optional, Union
def find_user(user_id: str) -> Optional[User]:
...
模式3:使用守卫进行类型缩小
使用条件语句为类型检查器缩小类型。
def process_user(user_id: str) -> UserData:
user = find_user(user_id)
if user is None:
raise UserNotFoundError(f"用户 {user_id} 未找到")
# 类型检查器知道此处user是User类型,而非User | None
return UserData(
name=user.name,
email=user.email,
)
def process_items(items: list[Item | None]) -> list[ProcessedItem]:
# 过滤并缩小类型
valid_items = [item for item in items if item is not None]
# valid_items现在为list[Item]
return [process(item) for item in valid_items]
模式4:泛型类
创建类型安全的可重用容器。
from typing import TypeVar, Generic
T = TypeVar("T")
E = TypeVar("E", bound=Exception)
class Result(Generic[T, E]):
"""表示成功值或错误。"""
def __init__(
self,
value: T | None = None,
error: E | None = None,
) -> None:
if (value is None) == (error is None):
raise ValueError("必须设置值或错误之一")
self._value = value
self._error = error
@property
def is_success(self) -> bool:
return self._error is None
@property
def is_failure(self) -> bool:
return self._error is not None
def unwrap(self) -> T:
"""获取值或引发错误。"""
if self._error is not None:
raise self._error
return self._value # type: ignore[return-value]
def unwrap_or(self, default: T) -> T:
"""获取值或返回默认值。"""
if self._error is not None:
return default
return self._value # type: ignore[return-value]
# 使用保留类型
def parse_config(path: str) -> Result[Config, ConfigError]:
try:
return Result(value=Config.from_file(path))
except ConfigError as e:
return Result(error=e)
result = parse_config("config.yaml")
if result.is_success:
config = result.unwrap() # 类型: Config
高级模式
模式5:泛型仓库
创建类型安全的数据访问模式。
from typing import TypeVar, Generic
from abc import ABC, abstractmethod
T = TypeVar("T")
ID = TypeVar("ID")
class Repository(ABC, Generic[T, ID]):
"""泛型仓库接口。"""
@abstractmethod
async def get(self, id: ID) -> T | None:
"""通过ID获取实体。"""
...
@abstractmethod
async def save(self, entity: T) -> T:
"""保存并返回实体。"""
...
@abstractmethod
async def delete(self, id: ID) -> bool:
"""删除实体,如果存在则返回True。"""
...
class UserRepository(Repository[User, str]):
"""具有字符串ID的用户具体仓库。"""
async def get(self, id: str) -> User | None:
row = await self._db.fetchrow(
"SELECT * FROM users WHERE id = $1", id
)
return User(**row) if row else None
async def save(self, entity: User) -> User:
...
async def delete(self, id: str) -> bool:
...
模式6:带界限的TypeVar
将泛型参数限制为特定类型。
from typing import TypeVar
from pydantic import BaseModel
ModelT = TypeVar("ModelT", bound=BaseModel)
def validate_and_create(model_cls: type[ModelT], data: dict) -> ModelT:
"""从字典创建验证后的Pydantic模型。"""
return model_cls.model_validate(data)
# 适用于任何BaseModel子类
class User(BaseModel):
name: str
email: str
user = validate_and_create(User, {"name": "Alice", "email": "a@b.com"})
# user类型化为User
# 类型错误: str不是BaseModel子类
result = validate_and_create(str, {"name": "Alice"}) # 错误!
模式7:协议用于结构类型
无需继承定义接口。
from typing import Protocol, runtime_checkable
@runtime_checkable
class Serializable(Protocol):
"""任何可以序列化到/从字典的类。"""
def to_dict(self) -> dict:
...
@classmethod
def from_dict(cls, data: dict) -> "Serializable":
...
# User满足Serializable协议而无需继承
class User:
def __init__(self, id: str, name: str) -> None:
self.id = id
self.name = name
def to_dict(self) -> dict:
return {"id": self.id, "name": self.name}
@classmethod
def from_dict(cls, data: dict) -> "User":
return cls(id=data["id"], name=data["name"])
def serialize(obj: Serializable) -> str:
"""适用于任何Serializable对象。"""
return json.dumps(obj.to_dict())
# 有效 - User匹配协议
serialize(User("1", "Alice"))
# 运行时检查使用@runtime_checkable
isinstance(User("1", "Alice"), Serializable) # True
模式8:常见协议模式
定义可重用的结构接口。
from typing import Protocol
class Closeable(Protocol):
"""可以关闭的资源。"""
def close(self) -> None: ...
class AsyncCloseable(Protocol):
"""可以关闭的异步资源。"""
async def close(self) -> None: ...
class Readable(Protocol):
"""可以从中读取的对象。"""
def read(self, n: int = -1) -> bytes: ...
class HasId(Protocol):
"""具有ID属性的对象。"""
@property
def id(self) -> str: ...
class Comparable(Protocol):
"""支持比较的对象。"""
def __lt__(self, other: "Comparable") -> bool: ...
def __le__(self, other: "Comparable") -> bool: ...
模式9:类型别名
创建有意义的类型名称。
注意: type语句在Python 3.10中引入用于简单别名。泛型类型语句需要Python 3.12+。
# Python 3.10+ type语句用于简单别名
type UserId = str
type UserDict = dict[str, Any]
# Python 3.12+ type语句带泛型
type Handler[T] = Callable[[Request], T]
type AsyncHandler[T] = Callable[[Request], Awaitable[T]]
# Python 3.9-3.11风格(需要更广泛的兼容性)
from typing import TypeAlias
from collections.abc import Callable, Awaitable
UserId: TypeAlias = str
Handler: TypeAlias = Callable[[Request], Response]
# 使用
def register_handler(path: str, handler: Handler[Response]) -> None:
...
模式10:可调用类型
为函数参数和回调添加类型。
from collections.abc import Callable, Awaitable
# 同步回调
ProgressCallback = Callable[[int, int], None] # (当前, 总计)
# 异步回调
AsyncHandler = Callable[[Request], Awaitable[Response]]
# 带命名参数(使用协议)
class OnProgress(Protocol):
def __call__(
self,
current: int,
total: int,
*,
message: str = "",
) -> None: ...
def process_items(
items: list[Item],
on_progress: ProgressCallback | None = None,
) -> list[Result]:
for i, item in enumerate(items):
if on_progress:
on_progress(i, len(items))
...
配置
严格模式清单
对于mypy --strict合规:
# pyproject.toml
[tool.mypy]
python_version = "3.12"
strict = true
warn_return_any = true
warn_unused_ignores = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
no_implicit_optional = true
增量采用目标:
- 所有函数参数已注释
- 所有返回类型已注释
- 类属性已注释
- 最小化
Any使用(适用于真正的动态数据) - 泛型集合使用类型参数(
list[str]而非list)
对于现有代码库,使用# mypy: strict或配置每模块覆盖在pyproject.toml中启用严格模式。
最佳实践总结
- 注释所有公共API - 函数、方法、类属性
- 使用
T | None- 现代联合语法优于Optional[T] - 运行严格类型检查 - CI中的
mypy --strict - 使用泛型 - 在可重用代码中保留类型信息
- 定义协议 - 结构类型用于接口
- 缩小类型 - 使用守卫帮助类型检查器
- 绑定类型变量 - 将泛型限制为有意义的类型
- 创建类型别名 - 复杂类型的有意义名称
- 最小化
Any- 使用特定类型或泛型。Any适用于真正的动态数据或与未类型化的第三方代码交互时 - 用类型文档化 - 类型是可强制执行的文档