名称: 缓存 描述: 全面的缓存策略和模式,用于性能优化。适用于实现缓存层、缓存失效、TTL策略或分布式缓存。涵盖Redis/Memcached模式、CDN缓存、数据库查询缓存、机器学习模型缓存和淘汰策略。触发器: cache, caching, Redis, Memcached, CDN, TTL, invalidation, eviction, LRU, LFU, FIFO, write-through, write-behind, cache-aside, read-through, cache stampede, distributed cache, local cache, memoization, query cache, result cache, edge cache, browser cache, HTTP cache。
缓存
概述
缓存通过将频繁访问的数据存储在更接近消费者的位置来提高应用程序性能。此技能涵盖缓存策略(旁路、直写、回写)、失效模式、TTL管理、Redis/Memcached使用、击穿预防和分布式缓存。
指令
1. 缓存策略
缓存旁路(惰性加载)
import bisect
import json
from typing import Any, Callable, Optional, TypeVar

T = TypeVar('T')
class CacheAside:
    """Cache-aside (lazy loading): the application manages the cache explicitly.

    Reads consult the cache first and fall back to ``loader`` on a miss;
    the loaded value is then stored with a TTL. Values are JSON-serialized.
    """

    def __init__(self, cache_client, default_ttl: int = 3600):
        self.cache = cache_client
        self.default_ttl = default_ttl

    async def get_or_load(
        self,
        key: str,
        loader: Callable[[], T],
        ttl: Optional[int] = None
    ) -> T:
        """Return the cached value for ``key``, loading and caching it on a miss."""
        hit = await self.cache.get(key)
        if hit is not None:
            return json.loads(hit)

        # Miss: load from the source of truth, then populate the cache.
        value = await loader()
        expiry = ttl or self.default_ttl
        await self.cache.setex(key, expiry, json.dumps(value))
        return value

    async def invalidate(self, key: str):
        """Drop ``key`` from the cache (e.g. after the source data changed)."""
        await self.cache.delete(key)
# Usage
cache = CacheAside(redis_client)


async def get_user(user_id: str) -> User:
    """Fetch a user through the cache-aside layer with a 5-minute TTL."""
    cache_key = f"user:{user_id}"

    def load():
        # Deferred so the database is only hit on a cache miss.
        return database.get_user(user_id)

    return await cache.get_or_load(cache_key, load, ttl=300)
直写缓存
class WriteThrough:
    """Write-through cache: every write hits the database first, then the cache.

    Bug fix: the annotations used the builtin function ``any`` as a type
    (``value: any``, ``Optional[any]``); corrected to ``typing.Any``.
    """

    def __init__(self, cache_client, database, default_ttl: int = 3600):
        self.cache = cache_client
        self.database = database
        self.default_ttl = default_ttl

    async def write(self, key: str, value: Any, ttl: Optional[int] = None):
        """Persist ``value`` to the database, then refresh the cache entry."""
        # Database first: if it fails, the cache never holds unpersisted data.
        await self.database.save(key, value)
        await self.cache.setex(
            key,
            ttl or self.default_ttl,
            json.dumps(value)
        )

    async def read(self, key: str) -> Optional[Any]:
        """Read from the cache, falling back to (and re-populating from) the database."""
        cached = await self.cache.get(key)
        if cached:
            return json.loads(cached)
        # Fall back to the database and warm the cache on a hit.
        value = await self.database.get(key)
        if value:
            await self.cache.setex(key, self.default_ttl, json.dumps(value))
        return value
回写(写回)缓存
import asyncio
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Any
import time
@dataclass
class PendingWrite:
    """A queued database write: the key, its value, and when it was enqueued."""
    key: str
    value: Any
    timestamp: float


class WriteBehind:
    """Write-behind (write-back) cache: the cache is updated synchronously,
    the database asynchronously in periodic/batched flushes.

    Bug fix: ``write()`` previously awaited ``_flush_pending()`` while still
    holding ``self._lock``; ``_flush_pending()`` re-acquires the same
    non-reentrant ``asyncio.Lock``, deadlocking as soon as the batch-size
    threshold was reached. The flush is now triggered after releasing the
    lock, and the database batch write happens outside the lock so writers
    are not blocked behind slow I/O. ``stop()`` now also awaits the cancelled
    flush task before the final drain.
    """

    def __init__(
        self,
        cache_client,
        database,
        flush_interval: float = 5.0,
        batch_size: int = 100
    ):
        self.cache = cache_client
        self.database = database
        self.flush_interval = flush_interval  # seconds between background flushes
        self.batch_size = batch_size  # pending count that forces an immediate flush
        self.pending_writes: Dict[str, PendingWrite] = {}
        self._lock = asyncio.Lock()
        self._flush_task = None

    async def start(self):
        """Start the periodic background flush loop."""
        self._flush_task = asyncio.create_task(self._flush_loop())

    async def stop(self):
        """Cancel the flush loop and drain any remaining pending writes."""
        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task  # let the cancellation actually land
            except asyncio.CancelledError:
                pass
            self._flush_task = None
        await self._flush_pending()

    async def write(self, key: str, value: Any):
        """Update the cache immediately; queue the database write for a batch."""
        await self.cache.set(key, json.dumps(value))
        async with self._lock:
            self.pending_writes[key] = PendingWrite(
                key=key,
                value=value,
                timestamp=time.time()
            )
            flush_now = len(self.pending_writes) >= self.batch_size
        # Flush AFTER releasing the lock (re-acquiring it would deadlock).
        if flush_now:
            await self._flush_pending()

    async def _flush_loop(self):
        """Flush pending writes every ``flush_interval`` seconds."""
        while True:
            await asyncio.sleep(self.flush_interval)
            await self._flush_pending()

    async def _flush_pending(self):
        """Atomically take the pending batch, then persist it outside the lock."""
        async with self._lock:
            if not self.pending_writes:
                return
            writes = list(self.pending_writes.values())
            self.pending_writes.clear()
        # Batch write to the database without holding the lock.
        await self.database.batch_save(
            [(w.key, w.value) for w in writes]
        )
直读缓存
class ReadThrough:
    """Read-through cache: a miss transparently loads from the backing loader."""

    def __init__(self, cache_client, loader, default_ttl: int = 3600):
        self.cache = cache_client
        self.loader = loader
        self.default_ttl = default_ttl

    async def get(self, key: str) -> Optional[Any]:
        """Return the value for ``key``, auto-loading and caching on a miss."""
        raw = await self.cache.get(key)
        if raw:
            return json.loads(raw)

        # Miss: delegate to the configured loader, cache non-None results.
        loaded = await self.loader(key)
        if loaded is not None:
            await self.cache.setex(key, self.default_ttl, json.dumps(loaded))
        return loaded
2. 失效策略
from enum import Enum
from typing import Set, List
import fnmatch
class InvalidationStrategy(Enum):
    """Supported cache-invalidation approaches (implemented by the helpers below)."""
    TIME_BASED = "time_based"  # entries expire via TTL
    EVENT_BASED = "event_based"  # pub/sub events trigger deletion
    VERSION_BASED = "version_based"  # bumping the version abandons old keys
class CacheInvalidator:
    """Invalidate cached entries by tag, key pattern, or version.

    Bug fix: the in-process ``_tag_index`` was only ever appended to — it
    grew without bound and went stale after any invalidation. It is now
    cleaned up by ``invalidate_by_tag`` and ``invalidate_by_pattern``.
    """

    def __init__(self, cache_client):
        self.cache = cache_client
        # Local mirror of tag -> keys, kept in sync with the Redis tag sets.
        self._tag_index: Dict[str, Set[str]] = defaultdict(set)

    # --- Tag-based invalidation ---
    async def set_with_tags(
        self,
        key: str,
        value: Any,
        tags: List[str],
        ttl: int = 3600
    ):
        """Cache ``value`` under ``key`` and register it with every tag in ``tags``."""
        await self.cache.setex(key, ttl, json.dumps(value))
        for tag in tags:
            self._tag_index[tag].add(key)
            await self.cache.sadd(f"tag:{tag}", key)

    async def invalidate_by_tag(self, tag: str):
        """Delete every key registered under ``tag`` plus the tag set itself."""
        keys = await self.cache.smembers(f"tag:{tag}")
        if keys:
            await self.cache.delete(*keys)
        await self.cache.delete(f"tag:{tag}")
        # Keep the local index in sync with the cache.
        self._tag_index.pop(tag, None)

    # --- Pattern-based invalidation ---
    async def invalidate_by_pattern(self, pattern: str):
        """Delete all keys matching ``pattern`` using incremental SCAN (non-blocking)."""
        cursor = 0
        while True:
            cursor, keys = await self.cache.scan(
                cursor,
                match=pattern,
                count=100
            )
            if keys:
                await self.cache.delete(*keys)
                # Drop deleted keys from every local tag set as well.
                for members in self._tag_index.values():
                    members.difference_update(keys)
            if cursor == 0:
                break

    # --- Version-based invalidation ---
    async def get_versioned(self, key: str, version: int) -> Optional[Any]:
        """Read the value stored for a specific ``version`` of ``key``."""
        versioned_key = f"{key}:v{version}"
        return await self.cache.get(versioned_key)

    async def set_versioned(
        self,
        key: str,
        value: Any,
        version: int,
        ttl: int = 3600
    ):
        """Write under a version-suffixed key; bumping the version implicitly
        abandons older entries (they expire via their TTL)."""
        versioned_key = f"{key}:v{version}"
        await self.cache.setex(versioned_key, ttl, json.dumps(value))
# 基于事件的失效,使用发布/订阅
class EventBasedInvalidator:
    """Propagates cache invalidations across processes via pub/sub.

    Bug fix: ``_handle_invalidation`` dispatched to
    ``self.invalidate_by_pattern`` / ``self.invalidate_by_tag``, which do not
    exist on this class (they live on ``CacheInvalidator``), so pattern/tag
    events raised ``AttributeError``. Both operations are now implemented
    locally as private helpers.
    """

    def __init__(self, cache_client, pubsub_client):
        self.cache = cache_client
        self.pubsub = pubsub_client

    async def start_listener(self):
        """Subscribe to the invalidation channel and apply events forever."""
        await self.pubsub.subscribe("cache:invalidate")
        async for message in self.pubsub.listen():
            if message["type"] == "message":
                data = json.loads(message["data"])
                await self._handle_invalidation(data)

    async def _handle_invalidation(self, data: dict):
        """Dispatch one invalidation event by its payload shape."""
        if "key" in data:
            await self.cache.delete(data["key"])
        elif "pattern" in data:
            await self._invalidate_pattern(data["pattern"])
        elif "tag" in data:
            await self._invalidate_tag(data["tag"])

    async def _invalidate_pattern(self, pattern: str):
        """Delete all keys matching ``pattern`` via incremental SCAN."""
        cursor = 0
        while True:
            cursor, keys = await self.cache.scan(cursor, match=pattern, count=100)
            if keys:
                await self.cache.delete(*keys)
            if cursor == 0:
                break

    async def _invalidate_tag(self, tag: str):
        """Delete every key registered under ``tag``, then the tag set itself."""
        keys = await self.cache.smembers(f"tag:{tag}")
        if keys:
            await self.cache.delete(*keys)
        await self.cache.delete(f"tag:{tag}")

    async def publish_invalidation(self, **kwargs):
        """Broadcast an invalidation event (``key=...`` / ``pattern=...`` / ``tag=...``)."""
        await self.pubsub.publish(
            "cache:invalidate",
            json.dumps(kwargs)
        )
3. TTL和过期
import random
from datetime import datetime, timedelta
class TTLManager:
    """Helpers for choosing TTLs: jittered, sliding, and per-data-type tiers."""

    def __init__(self, base_ttl: int = 3600):
        self.base_ttl = base_ttl

    def get_ttl_with_jitter(self, ttl: Optional[int] = None) -> int:
        """Randomize the TTL by up to ±10% so entries don't all expire at once."""
        base = ttl or self.base_ttl
        return int(base + base * random.uniform(-0.1, 0.1))

    def get_sliding_ttl(
        self,
        last_access: datetime,
        min_ttl: int = 60,
        max_ttl: int = 3600
    ) -> int:
        """Pick a TTL from how recently the entry was accessed (hotter = longer)."""
        idle = (datetime.utcnow() - last_access).total_seconds()
        if idle < 60:
            return max_ttl  # accessed frequently
        if idle < 300:
            return max_ttl // 2
        return min_ttl

    def get_tiered_ttl(self, data_type: str) -> int:
        """Map a data type to its tier's TTL, defaulting to ``base_ttl``."""
        return {
            "user_session": 86400,  # 1 day
            "user_profile": 3600,  # 1 hour
            "product_catalog": 300,  # 5 minutes
            "search_results": 60,  # 1 minute
            "real_time_data": 10,  # 10 seconds
        }.get(data_type, self.base_ttl)
# 滑动窗口过期
class SlidingCache:
def __init__(self, cache_client, default_ttl: int = 3600):
self.cache = cache_client
self.default_ttl = default_ttl
async def get(self, key: str) -> Optional[Any]:
value = await self.cache.get(key)
if value:
# 访问时刷新TTL
await self.cache.expire(key, self.default_ttl)
return json.loads(value)
return None
async def set(self, key: str, value: Any, ttl: Optional[int] = None):
await self.cache.setex(
key,
ttl or self.default_ttl,
json.dumps(value)
)
4. Redis/Memcached模式
import redis.asyncio as redis
from typing import List, Tuple
class RedisCache:
    """Thin async Redis wrapper showcasing common data-structure patterns."""

    def __init__(self, redis_url: str):
        self.redis = redis.from_url(redis_url)

    # --- Hashes: structured records ---
    async def set_hash(self, key: str, data: dict, ttl: int = 3600):
        """Store ``data`` as a Redis hash and apply a TTL to the whole key."""
        await self.redis.hset(key, mapping=data)
        await self.redis.expire(key, ttl)

    async def get_hash(self, key: str) -> Optional[dict]:
        """Fetch a whole hash, decoding bytes to str; ``None`` when absent/empty."""
        raw = await self.redis.hgetall(key)
        if not raw:
            return None
        return {field.decode(): val.decode() for field, val in raw.items()}

    async def update_hash_field(self, key: str, field: str, value: str):
        """Overwrite a single field of an existing hash."""
        await self.redis.hset(key, field, value)

    # --- Sorted sets: leaderboards / rankings ---
    async def add_to_ranking(self, key: str, member: str, score: float):
        """Insert/update ``member`` with ``score`` in the sorted set."""
        await self.redis.zadd(key, {member: score})

    async def get_top_n(self, key: str, n: int) -> List[Tuple[str, float]]:
        """Return the ``n`` highest-scored members with their scores."""
        return await self.redis.zrevrange(key, 0, n - 1, withscores=True)

    # --- Lists: simple queues ---
    async def push_to_queue(self, key: str, *values):
        await self.redis.lpush(key, *values)

    async def pop_from_queue(self, key: str, timeout: int = 0):
        """Blocking pop from the queue tail (``timeout=0`` waits forever)."""
        return await self.redis.brpop(key, timeout=timeout)

    # --- Pub/Sub: cache-invalidation broadcasts ---
    async def publish(self, channel: str, message: str):
        await self.redis.publish(channel, message)

    # --- Lua: atomic check-then-increment ---
    async def increment_with_cap(self, key: str, cap: int) -> int:
        """Atomically increment ``key`` unless it already reached ``cap`` (returns -1)."""
        script = """
        local current = redis.call('GET', KEYS[1])
        if current and tonumber(current) >= tonumber(ARGV[1]) then
            return -1
        end
        return redis.call('INCR', KEYS[1])
        """
        return await self.redis.eval(script, 1, key, cap)

    # --- Pipelines: batched round-trips ---
    async def batch_get(self, keys: List[str]) -> List[Optional[str]]:
        """Fetch many keys in one round-trip; result order matches ``keys``."""
        async with self.redis.pipeline() as pipe:
            for k in keys:
                pipe.get(k)
            return await pipe.execute()

    async def batch_set(
        self,
        items: List[Tuple[str, str]],
        ttl: int = 3600
    ):
        """Set many key/value pairs via one pipeline, all with the same TTL."""
        async with self.redis.pipeline() as pipe:
            for k, v in items:
                pipe.setex(k, ttl, v)
            await pipe.execute()
5. 缓存击穿预防
import asyncio
import hashlib
import time
from typing import Optional, Callable
class StampedeProtection:
    """Prevents cache stampedes (many callers regenerating the same expired
    entry simultaneously) via three alternative strategies: per-key locking,
    probabilistic early recomputation, and stale-while-revalidate.
    """

    def __init__(self, cache_client):
        self.cache = cache_client
        # Per-key asyncio locks. NOTE(review): entries are never removed, so
        # this dict grows with the number of distinct keys — confirm acceptable.
        self._locks: Dict[str, asyncio.Lock] = {}

    # Strategy 1: locking (prevents concurrent regeneration)
    async def get_with_lock(
        self,
        key: str,
        loader: Callable,
        ttl: int = 3600
    ) -> Any:
        """Return the cached value; on a miss, only one coroutine runs ``loader``."""
        cached = await self.cache.get(key)
        if cached:
            return json.loads(cached)
        # Get or create the lock for this key (no await between check and set,
        # so this is race-free within a single event loop).
        if key not in self._locks:
            self._locks[key] = asyncio.Lock()
        async with self._locks[key]:
            # Double-check after acquiring: another coroutine may have
            # repopulated the entry while we waited for the lock.
            cached = await self.cache.get(key)
            if cached:
                return json.loads(cached)
            value = await loader()
            await self.cache.setex(key, ttl, json.dumps(value))
            return value

    # Strategy 2: probabilistic early expiration
    async def get_with_early_recompute(
        self,
        key: str,
        loader: Callable,
        ttl: int = 3600,
        beta: float = 1.0
    ) -> Any:
        """Serve the cached value, but randomly recompute shortly before expiry.

        ``_delta`` is the last observed recompute cost; the closer ``now`` is
        to ``_expiry`` the likelier an early refresh. NOTE(review): the
        canonical XFetch formula uses ``delta * beta * log(random())``; this
        variant omits the log — confirm the simplification is intentional.
        """
        data = await self.cache.get(key)
        if data:
            cached = json.loads(data)
            expiry = cached.get("_expiry", 0)
            delta = cached.get("_delta", 0)
            # Probabilistic early recompute decision
            now = time.time()
            if now - delta * beta * random.random() < expiry:
                return cached["value"]
        # Recompute, timing the load so the next decision knows its cost.
        start = time.time()
        value = await loader()
        delta = time.time() - start
        cache_data = {
            "value": value,
            "_expiry": time.time() + ttl,
            "_delta": delta
        }
        await self.cache.setex(key, ttl, json.dumps(cache_data))
        return value

    # Strategy 3: serve stale data while revalidating
    async def get_stale_while_revalidate(
        self,
        key: str,
        loader: Callable,
        ttl: int = 3600,
        stale_ttl: int = 300
    ) -> Any:
        """Serve stale data immediately while refreshing in the background."""
        data = await self.cache.get(key)
        if data:
            cached = json.loads(data)
            if cached.get("_fresh", True):
                return cached["value"]
            # Stale entry: return it now, refresh in a background task.
            asyncio.create_task(self._revalidate(key, loader, ttl, stale_ttl))
            return cached["value"]
        return await self._load_and_cache(key, loader, ttl, stale_ttl)

    async def _revalidate(
        self,
        key: str,
        loader: Callable,
        ttl: int,
        stale_ttl: int
    ):
        """Refresh ``key`` behind a short-lived SETNX lock so only one
        process revalidates at a time (lock auto-expires after 30s)."""
        lock_key = f"lock:{key}"
        acquired = await self.cache.setnx(lock_key, "1")
        if acquired:
            try:
                await self.cache.expire(lock_key, 30)
                await self._load_and_cache(key, loader, ttl, stale_ttl)
            finally:
                await self.cache.delete(lock_key)

    async def _load_and_cache(
        self,
        key: str,
        loader: Callable,
        ttl: int,
        stale_ttl: int
    ):
        """Load, cache as fresh, and schedule the entry to be marked stale."""
        value = await loader()
        cache_data = {"value": value, "_fresh": True}
        await self.cache.setex(key, ttl, json.dumps(cache_data))
        # Mark the entry stale after (ttl - stale_ttl).
        # NOTE(review): assumes ttl > stale_ttl and that this task outlives the
        # caller; if the entry is overwritten meanwhile, the new value gets
        # re-marked stale on this old schedule — confirm acceptable.
        async def mark_stale():
            await asyncio.sleep(ttl - stale_ttl)
            data = await self.cache.get(key)
            if data:
                cached = json.loads(data)
                cached["_fresh"] = False
                await self.cache.setex(key, stale_ttl, json.dumps(cached))
        asyncio.create_task(mark_stale())
        return value
6. 分布式缓存
import hashlib
from typing import List, Optional
import random
class ConsistentHashing:
    """Consistent hash ring mapping keys to cache nodes via virtual replicas.

    Improvements: ``get_node`` now uses binary search (O(log n)) instead of a
    linear scan over the sorted ring, and ``add_node`` keeps the key list
    sorted incrementally with ``bisect.insort`` instead of re-sorting after
    every inserted replica.
    """

    def __init__(self, nodes: List[str], replicas: int = 100):
        # replicas: virtual points per physical node — evens out distribution.
        self.replicas = replicas
        self.ring: Dict[int, str] = {}
        self.sorted_keys: List[int] = []
        for node in nodes:
            self.add_node(node)

    def _hash(self, key: str) -> int:
        # md5 is used for placement only (not security); stable across processes.
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def add_node(self, node: str):
        """Insert ``replicas`` virtual points for ``node`` into the ring."""
        for i in range(self.replicas):
            point = self._hash(f"{node}:{i}")
            self.ring[point] = node
            bisect.insort(self.sorted_keys, point)

    def remove_node(self, node: str):
        """Remove all of ``node``'s virtual points from the ring."""
        for i in range(self.replicas):
            point = self._hash(f"{node}:{i}")
            del self.ring[point]
            self.sorted_keys.remove(point)

    def get_node(self, key: str) -> str:
        """Return the node owning ``key``: the first ring point at or after its
        hash, wrapping to the first point when the hash is past the end."""
        if not self.ring:
            raise ValueError("环中没有节点")
        h = self._hash(key)
        idx = bisect.bisect_left(self.sorted_keys, h)
        if idx == len(self.sorted_keys):
            idx = 0  # wrap around the ring
        return self.ring[self.sorted_keys[idx]]
class DistributedCache:
    """Shards keys across multiple Redis nodes using consistent hashing."""

    def __init__(self, nodes: List[str]):
        self.ring = ConsistentHashing(nodes)
        self.clients: Dict[str, RedisCache] = {
            node: RedisCache(node) for node in nodes
        }

    def _get_client(self, key: str) -> RedisCache:
        """Resolve the node owning ``key`` and return its client."""
        return self.clients[self.ring.get_node(key)]

    async def get(self, key: str) -> Optional[str]:
        return await self._get_client(key).redis.get(key)

    async def set(self, key: str, value: str, ttl: int = 3600):
        await self._get_client(key).redis.setex(key, ttl, value)

    async def delete(self, key: str):
        await self._get_client(key).redis.delete(key)

    # Multi-key get across nodes
    async def mget(self, keys: List[str]) -> Dict[str, Optional[str]]:
        """Group keys per owning node, then fetch the groups in parallel."""
        per_node: Dict[str, List[str]] = defaultdict(list)
        for k in keys:
            per_node[self.ring.get_node(k)].append(k)

        results: Dict[str, Optional[str]] = {}
        await asyncio.gather(*(
            self._fetch_from_node(self.clients[node], node_keys, results)
            for node, node_keys in per_node.items()
        ))
        return results

    async def _fetch_from_node(
        self,
        client: RedisCache,
        keys: List[str],
        results: Dict
    ):
        """Batch-fetch ``keys`` from one node, merging into the shared ``results``."""
        for k, v in zip(keys, await client.batch_get(keys)):
            results[k] = v
7. 数据库查询缓存
import sqlalchemy
from typing import Optional, List, Any
import hashlib
class QueryCache:
    """Caches database query results keyed by a hash of SQL text + parameters."""

    def __init__(self, cache_client, default_ttl: int = 300):
        self.cache = cache_client
        self.default_ttl = default_ttl

    def _query_key(self, sql: str, params: tuple) -> str:
        """Derive a stable cache key from the SQL text and bound parameters."""
        digest = hashlib.md5(f"{sql}:{params}".encode()).hexdigest()
        return f"query:{digest}"

    async def execute_cached(
        self,
        session,
        query,
        params: Optional[dict] = None,
        ttl: Optional[int] = None
    ) -> List[Any]:
        """Run ``query`` through the cache; rows are returned as a list of dicts."""
        normalized = tuple(sorted((params or {}).items()))
        cache_key = self._query_key(str(query), normalized)

        hit = await self.cache.get(cache_key)
        if hit:
            return json.loads(hit)

        # Miss: execute and serialize the rows before caching.
        rows = [dict(row) for row in session.execute(query, params).fetchall()]
        await self.cache.setex(
            cache_key,
            ttl or self.default_ttl,
            json.dumps(rows)
        )
        return rows

    async def invalidate_table(self, table_name: str):
        """Invalidate every cached query whose key matches the table pattern.

        NOTE(review): cache keys are md5 digests, so ``query:*{table}*`` can
        only match if ``delete_pattern`` inspects something other than the
        hashed key text — verify against the cache client's implementation.
        """
        pattern = f"query:*{table_name}*"
        await self.cache.delete_pattern(pattern)
# 使用SQLAlchemy的ORM级缓存
from sqlalchemy import event
from sqlalchemy.orm import Session
class ORMCache:
    """Caches ORM objects by primary key and auto-invalidates query caches on writes."""

    def __init__(self, cache_client):
        self.cache = cache_client

    def setup_listeners(self, engine):
        """Register automatic cache invalidation when a session flushes writes."""
        @event.listens_for(Session, "after_flush")
        def receive_after_flush(session, flush_context):
            # Invalidate cached queries for every table touched by this flush.
            # NOTE(review): asyncio.create_task requires a running event loop;
            # SQLAlchemy flush hooks are synchronous, so this raises
            # RuntimeError outside async contexts — confirm the deployment
            # always runs under an event loop.
            for obj in session.dirty | session.new | session.deleted:
                table = obj.__tablename__
                asyncio.create_task(
                    self.cache.delete_pattern(f"query:*{table}*")
                )

    async def get_by_id(self, model, obj_id: int, session):
        """Fetch a row by primary key through the cache (1h TTL on fill).

        NOTE(review): a cache hit returns a plain dict while a miss returns
        the ORM object — callers must handle both. Also assumes ``model``
        instances expose ``to_dict()``, and ``Query.get()`` is legacy API in
        SQLAlchemy 1.4+ — verify the library version in use.
        """
        key = f"{model.__tablename__}:{obj_id}"
        cached = await self.cache.get(key)
        if cached:
            return json.loads(cached)
        obj = session.query(model).get(obj_id)
        if obj:
            await self.cache.setex(key, 3600, json.dumps(obj.to_dict()))
        return obj
8. CDN和静态资源缓存
from datetime import datetime, timedelta
from typing import Dict, Optional
class CDNCache:
    """CDN/static-asset caching strategies expressed as HTTP response headers.

    Bug fix: the "versioned" and "revalidate" strategies previously included
    a placeholder ``"ETag": None`` entry; a ``None`` header value breaks most
    HTTP frameworks if the dict is used directly. The placeholder is gone —
    callers that need an ETag set it themselves (as ``StaticAssetCache`` does).
    """

    @staticmethod
    def get_cache_headers(
        cache_type: str,
        max_age: int = 3600
    ) -> Dict[str, str]:
        """Return the HTTP cache headers for ``cache_type``.

        Unknown types fall back to the conservative "revalidate" strategy.
        """
        expires = (datetime.utcnow() + timedelta(seconds=max_age)).strftime(
            "%a, %d %b %Y %H:%M:%S GMT"
        )
        strategies = {
            "immutable": {
                "Cache-Control": f"public, max-age={max_age}, immutable",
                "Expires": expires,
            },
            "versioned": {
                # ETag is added dynamically by the caller.
                "Cache-Control": f"public, max-age={max_age}",
            },
            "revalidate": {
                # ETag is added dynamically by the caller.
                "Cache-Control": "public, max-age=0, must-revalidate",
            },
            "private": {
                "Cache-Control": "private, max-age=0, must-revalidate",
                "Pragma": "no-cache",
            },
            "no-cache": {
                "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
                "Pragma": "no-cache",
                "Expires": "0",
            },
        }
        return strategies.get(cache_type, strategies["revalidate"])

    @staticmethod
    def generate_etag(content: bytes) -> str:
        """Generate an ETag from a content hash (md5 is fine: not security-sensitive)."""
        return hashlib.md5(content).hexdigest()
# FastAPI/Flask示例
from fastapi import Response
from fastapi.responses import FileResponse
class StaticAssetCache:
    """Serves static assets with cache headers and builds versioned CDN URLs.

    Bug fixes: ``get_asset_url`` raised ``IndexError`` for paths without an
    extension; ``serve_asset`` read the whole file even when no ETag was
    needed.
    """

    def __init__(self, cdn_base_url: Optional[str] = None):
        self.cdn_base_url = cdn_base_url

    def serve_asset(
        self,
        file_path: str,
        asset_type: str = "immutable"
    ) -> FileResponse:
        """Serve a static asset with the appropriate cache headers."""
        headers = CDNCache.get_cache_headers(
            asset_type,
            max_age=31536000 if asset_type == "immutable" else 3600
        )
        # Only read the file when an ETag is actually required.
        if asset_type in ["versioned", "revalidate"]:
            with open(file_path, "rb") as f:
                content = f.read()
            headers["ETag"] = f'"{CDNCache.generate_etag(content)}"'
        return FileResponse(
            file_path,
            headers=headers,
            media_type=self._get_media_type(file_path)
        )

    def get_asset_url(self, path: str, version: Optional[str] = None) -> str:
        """Build a CDN URL, optionally inserting a cache-busting version.

        Extensionless paths get the version appended as a suffix instead of
        raising (the original indexed past the end of ``rsplit``'s result).
        """
        base = self.cdn_base_url or ""
        if version:
            stem, dot, ext = path.rpartition(".")
            path = f"{stem}.{version}.{ext}" if dot else f"{path}.{version}"
        return f"{base}/{path}"

    @staticmethod
    def _get_media_type(file_path: str) -> str:
        """Map a file extension to its MIME type (octet-stream fallback)."""
        ext_to_mime = {
            ".js": "application/javascript",
            ".css": "text/css",
            ".png": "image/png",
            ".jpg": "image/jpeg",
            ".svg": "image/svg+xml",
            ".woff2": "font/woff2",
        }
        ext = file_path[file_path.rfind("."):]
        return ext_to_mime.get(ext, "application/octet-stream")
9. 机器学习模型缓存
import pickle
from typing import Any, Callable
import numpy as np
class MLModelCache:
    """Caches ML models (pickled) and JSON-serializable prediction results."""

    def __init__(self, cache_client, model_store_path: str = "./models"):
        self.cache = cache_client
        self.model_store = model_store_path

    async def get_model(self, model_id: str, version: str):
        """Load a model from the cache, falling back to the on-disk store.

        SECURITY: ``pickle.loads``/``pickle.load`` execute arbitrary code —
        only safe if the cache contents and the model store are fully
        trusted. Never feed untrusted bytes into this path.
        """
        key = f"model:{model_id}:v{version}"
        cached = await self.cache.get(key)
        if cached:
            return pickle.loads(cached)
        # Load from disk
        path = f"{self.model_store}/{model_id}/{version}/model.pkl"
        with open(path, "rb") as f:
            model = pickle.load(f)
        # Cache the serialized model (consider compression for large models)
        await self.cache.setex(
            key,
            86400,  # 1 day
            pickle.dumps(model)
        )
        return model

    async def cache_prediction(
        self,
        model_id: str,
        input_hash: str,
        prediction: Any,
        ttl: int = 3600
    ):
        """Store a (JSON-serializable) prediction under a pre-computed input hash."""
        key = f"pred:{model_id}:{input_hash}"
        await self.cache.setex(key, ttl, json.dumps(prediction))

    async def get_cached_prediction(
        self,
        model_id: str,
        input_data: Any
    ) -> Optional[Any]:
        """Look up a cached prediction by hashing the raw input.

        NOTE(review): this hashes ``json.dumps(input_data, sort_keys=True)``,
        while ``cache_prediction`` accepts an externally computed
        ``input_hash`` — callers must use the same hashing scheme for hits
        to occur; confirm at the call sites.
        """
        input_hash = hashlib.md5(
            json.dumps(input_data, sort_keys=True).encode()
        ).hexdigest()
        key = f"pred:{model_id}:{input_hash}"
        cached = await self.cache.get(key)
        return json.loads(cached) if cached else None
# 机器学习管道的特征缓存
class FeatureCache:
    """Per-entity feature cache for ML pipelines: serves hits, computes misses."""

    def __init__(self, cache_client):
        self.cache = cache_client

    async def get_features(
        self,
        entity_id: str,
        feature_names: List[str],
        compute_fn: Callable
    ) -> Dict[str, Any]:
        """Return the requested features, computing and caching any misses."""
        cache_keys = [f"feature:{entity_id}:{name}" for name in feature_names]
        raw_values = await self.cache.mget(cache_keys)

        features: Dict[str, Any] = {}
        missing: List[str] = []
        for name, raw in zip(feature_names, raw_values):
            if raw:
                features[name] = json.loads(raw)
            else:
                missing.append(name)

        if not missing:
            return features

        # Compute only the features the cache could not serve, then cache them.
        computed = await compute_fn(entity_id, missing)
        for name, value in computed.items():
            await self.cache.setex(f"feature:{entity_id}:{name}", 3600, json.dumps(value))
            features[name] = value
        return features
# 向量相似性的嵌入缓存
class EmbeddingCache:
    """Caches text embeddings as raw float32 bytes for vector-similarity use.

    Bug fix: the original stored ``embedding.tobytes()`` unchanged but decoded
    cache hits with ``dtype=np.float32``; an ``embed_fn`` returning float64
    (NumPy's default) silently round-tripped to a garbage vector of twice the
    length. Embeddings are now normalized to float32 before storing, so hits
    and misses return identical vectors.
    """

    def __init__(self, cache_client):
        self.cache = cache_client

    async def get_embedding(
        self,
        text: str,
        model: str,
        embed_fn: Callable
    ) -> np.ndarray:
        """Return the float32 embedding of ``text``, computing and caching on a miss."""
        text_hash = hashlib.md5(text.encode()).hexdigest()
        key = f"embed:{model}:{text_hash}"
        cached = await self.cache.get(key)
        if cached:
            return np.frombuffer(cached, dtype=np.float32)
        # Normalize to float32 so the binary round-trip is lossless.
        embedding = np.asarray(await embed_fn(text), dtype=np.float32)
        # Store as a compact raw binary buffer.
        await self.cache.setex(
            key,
            86400 * 7,  # 1 week
            embedding.tobytes()
        )
        return embedding
10. 本地内存缓存
from collections import OrderedDict
from threading import RLock
from typing import Optional
import time
class LRUCache:
    """Thread-safe LRU cache with optional per-entry TTL.

    Entries are ``(value, expiry)`` pairs in an ``OrderedDict``; the
    least-recently-used entry sits at the front and is evicted first.
    """

    def __init__(self, capacity: int = 1000, default_ttl: Optional[int] = None):
        self.cache = OrderedDict()
        self.capacity = capacity
        self.default_ttl = default_ttl
        self.lock = RLock()

    def get(self, key: str) -> Optional[Any]:
        """Return the live value for ``key``, or ``None`` (missing or expired)."""
        with self.lock:
            if key not in self.cache:
                return None
            value, expiry = self.cache[key]
            # Expired entries are dropped lazily on access.
            if expiry and time.time() > expiry:
                del self.cache[key]
                return None
            # Mark as most recently used.
            self.cache.move_to_end(key)
            return value

    def put(self, key: str, value: Any, ttl: Optional[int] = None):
        """Insert/overwrite ``key``; evicts the LRU entry when over capacity."""
        with self.lock:
            if key in self.cache:
                self.cache.move_to_end(key)
            ttl = ttl or self.default_ttl
            expiry = time.time() + ttl if ttl else None
            self.cache[key] = (value, expiry)
            if len(self.cache) > self.capacity:
                # Evict the oldest (least recently used) entry.
                self.cache.popitem(last=False)

    def invalidate(self, key: str):
        """Remove ``key`` if present (no error when absent)."""
        with self.lock:
            self.cache.pop(key, None)

    def clear(self):
        """Drop every entry."""
        with self.lock:
            self.cache.clear()


# Function memoization decorator
from functools import wraps


def memoize(ttl: Optional[int] = None, maxsize: int = 128):
    """Memoization decorator backed by ``LRUCache`` with optional TTL.

    Bug fix: results are stored wrapped in a 1-tuple so that functions
    returning ``None`` are cached too — the original treated a cached
    ``None`` as a miss and recomputed it on every call.
    """
    cache = LRUCache(capacity=maxsize, default_ttl=ttl)

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Build the cache key from the call signature.
            key = f"{func.__name__}:{args}:{sorted(kwargs.items())}"
            hit = cache.get(key)
            if hit is not None:
                return hit[0]  # unwrap the stored 1-tuple
            result = func(*args, **kwargs)
            cache.put(key, (result,))
            return result
        wrapper.cache_clear = cache.clear
        wrapper.cache_info = lambda: {
            "size": len(cache.cache),
            "maxsize": cache.capacity
        }
        return wrapper
    return decorator


# Usage
@memoize(ttl=300, maxsize=1000)
def expensive_computation(x: int, y: int) -> int:
    """Example: memoized integer exponentiation."""
    return x ** y
最佳实践
- 缓存重要内容:专注于频繁访问、计算成本高的数据。
- 设置适当TTL:平衡新鲜度和性能。使用抖动防止惊群效应。
- 优雅处理缓存故障:缓存应增强而非必需操作。
- 监控缓存性能:跟踪命中率、延迟和内存使用。
- 使用适当数据结构:对象用哈希,排名用有序集合,队列用列表。
- 实现正确失效:对于关键数据,优先事件驱动失效而非仅TTL。
- 预防击穿:使用锁定、早期重新计算或陈旧数据更新。
- 适当调整缓存大小:监控淘汰率并相应调整大小。
- 分层缓存:浏览器缓存 → CDN → Redis → 数据库以获得最佳性能。
- 安全考虑:未经加密绝不缓存敏感数据(密码、令牌、PII)。
示例
完整缓存层
class CacheLayer:
    """Production-ready cache facade combining stampede protection,
    tiered TTLs, and tag-based invalidation.

    Bug fix: the original never assigned ``self.database`` even though every
    read/write method below uses it, so each call raised ``AttributeError``.
    A backward-compatible ``database`` parameter (default ``None``) was added.
    """

    def __init__(self, redis_url: str, default_ttl: int = 3600, database=None):
        self.redis = RedisCache(redis_url)
        self.stampede = StampedeProtection(self.redis.redis)
        self.ttl_manager = TTLManager(default_ttl)
        self.invalidator = CacheInvalidator(self.redis.redis)
        # Backing data store used by the methods below.
        self.database = database

    async def get_user(self, user_id: str) -> User:
        """Load a user with per-key locking so a hot miss hits the DB only once."""
        return await self.stampede.get_with_lock(
            f"user:{user_id}",
            lambda: self.database.get_user(user_id),
            ttl=self.ttl_manager.get_tiered_ttl("user_profile")
        )

    async def update_user(self, user_id: str, data: dict):
        """Write through to the DB, then invalidate the user's tagged entries.

        NOTE(review): ``get_user`` caches via ``get_with_lock`` without
        registering tags, so tag invalidation only covers entries written
        with ``set_with_tags`` — confirm writers register the ``user:{id}``
        tag, or additionally delete the key directly.
        """
        await self.database.update_user(user_id, data)
        await self.invalidator.invalidate_by_tag(f"user:{user_id}")

    async def get_product_listing(self, category: str) -> List[Product]:
        """Serve possibly-stale product lists while refreshing in the background."""
        return await self.stampede.get_stale_while_revalidate(
            f"products:{category}",
            lambda: self.database.get_products(category),
            ttl=self.ttl_manager.get_tiered_ttl("product_catalog"),
            stale_ttl=60
        )