缓存技术(Skill: caching)

缓存技术是一种用于提升应用性能的关键技能,通过存储频繁访问的数据减少延迟和负载。涵盖缓存策略(如缓存旁路、直写、回写、直读)、失效模式(时间、事件、版本)、TTL管理、Redis/Memcached应用、缓存击穿预防、分布式缓存、数据库查询缓存、CDN优化和机器学习模型缓存等。适用于后端开发、系统架构、云计算和AI应用,关键词包括缓存优化、性能调优、Redis缓存、分布式系统、数据库查询加速、SEO性能、云计算缓存策略。

后端开发 0 次安装 0 次浏览 更新于 3/24/2026

名称: 缓存 描述: 全面的缓存策略和模式,用于性能优化。适用于实现缓存层、缓存失效、TTL策略或分布式缓存。涵盖Redis/Memcached模式、CDN缓存、数据库查询缓存、机器学习模型缓存和淘汰策略。触发器: cache, caching, Redis, Memcached, CDN, TTL, invalidation, eviction, LRU, LFU, FIFO, write-through, write-behind, cache-aside, read-through, cache stampede, distributed cache, local cache, memoization, query cache, result cache, edge cache, browser cache, HTTP cache。

缓存

概述

缓存通过将频繁访问的数据存储在更接近消费者的位置来提高应用程序性能。此技能涵盖缓存策略(旁路、直写、回写)、失效模式、TTL管理、Redis/Memcached使用、击穿预防和分布式缓存。

指令

1. 缓存策略

缓存旁路(惰性加载)

import json
from typing import Any, Awaitable, Callable, Optional, TypeVar

# Generic type of the values a loader produces and the cache returns.
T = TypeVar('T')

class CacheAside:
    """Cache-aside (lazy loading): the application manages the cache explicitly.

    Reads consult the cache first; on a miss the value is produced by a
    caller-supplied async loader, stored with a TTL and returned. Values are
    serialized with ``json`` and must therefore be JSON-serializable.
    """

    def __init__(self, cache_client, default_ttl: int = 3600):
        """
        Args:
            cache_client: async client exposing ``get``, ``setex`` and ``delete``.
            default_ttl: TTL in seconds used when ``get_or_load`` gets no ``ttl``.
        """
        self.cache = cache_client
        self.default_ttl = default_ttl

    async def get_or_load(
        self,
        key: str,
        loader: Callable[[], Awaitable[T]],
        ttl: Optional[int] = None
    ) -> T:
        """Return the cached value for ``key``, loading and caching it on a miss.

        Args:
            key: cache key.
            loader: zero-argument callable returning an awaitable of the value;
                it is only called (and awaited) on a cache miss.
            ttl: TTL in seconds; a falsy value (None or 0) uses ``default_ttl``.

        Returns:
            The cached or freshly loaded value. A ``None`` loader result is
            cached as JSON ``null`` and returned as ``None`` on later hits.
        """
        cached = await self.cache.get(key)
        if cached is not None:
            return json.loads(cached)

        value = await loader()

        await self.cache.setex(key, ttl or self.default_ttl, json.dumps(value))
        return value

    async def invalidate(self, key: str):
        """Remove ``key`` so the next ``get_or_load`` calls the loader again."""
        await self.cache.delete(key)

# Usage: serve user profiles through the cache-aside helper for five minutes.
cache = CacheAside(redis_client)

async def get_user(user_id: str) -> User:
    """Return the user, hitting ``database`` only on a cache miss."""
    user_key = f"user:{user_id}"
    return await cache.get_or_load(user_key, lambda: database.get_user(user_id), ttl=300)

直写缓存

class WriteThrough:
    """Write-through cache: every write updates the database and then the cache.

    Reads are served from the cache and fall back to the database on a miss,
    repopulating the cache. Values are stored as JSON.
    """

    def __init__(self, cache_client, database, default_ttl: int = 3600):
        """
        Args:
            cache_client: async client exposing ``get`` and ``setex``.
            database: async store exposing ``save(key, value)`` and ``get(key)``.
            default_ttl: TTL in seconds for cached entries.
        """
        self.cache = cache_client
        self.database = database
        self.default_ttl = default_ttl

    async def write(self, key: str, value: Any, ttl: Optional[int] = None):
        """Persist ``value`` to the database, then cache it.

        The database is written first so the cache never holds a value the
        database did not accept. A falsy ``ttl`` (None or 0) uses ``default_ttl``.
        """
        await self.database.save(key, value)
        await self.cache.setex(key, ttl or self.default_ttl, json.dumps(value))

    async def read(self, key: str) -> Optional[Any]:
        """Return the value for ``key`` from the cache, else from the database.

        Database hits are written back to the cache. Only ``None`` means
        "not found", so falsy values such as ``0``, ``""`` or ``[]`` are cached
        like any other value instead of hitting the database on every read.
        """
        cached = await self.cache.get(key)
        if cached is not None:
            return json.loads(cached)

        value = await self.database.get(key)
        if value is not None:
            await self.cache.setex(key, self.default_ttl, json.dumps(value))
        return value

回写(写回)缓存

import asyncio
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Any
import time

@dataclass
class PendingWrite:
    """A database write queued by ``WriteBehind`` until the next flush."""

    key: str
    value: Any
    timestamp: float  # time.time() at which the write was queued

class WriteBehind:
    """Write-behind cache: writes hit the cache immediately and reach the
    database later, in batches.

    Pending writes are coalesced per key (only the latest value is flushed).
    They are flushed every ``flush_interval`` seconds once ``start()`` has
    been called, whenever ``batch_size`` keys are pending, and on ``stop()``.
    If ``batch_save`` fails, the batch is put back in the queue (unless a
    newer write for the same key arrived meanwhile) so it is retried.
    """

    def __init__(
        self,
        cache_client,
        database,
        flush_interval: float = 5.0,
        batch_size: int = 100
    ):
        """
        Args:
            cache_client: async client exposing ``set``.
            database: async store exposing ``batch_save([(key, value), ...])``.
            flush_interval: seconds between background flushes.
            batch_size: number of pending keys that triggers an immediate flush.
        """
        self.cache = cache_client
        self.database = database
        self.flush_interval = flush_interval
        self.batch_size = batch_size
        self.pending_writes: Dict[str, PendingWrite] = {}
        self._lock = asyncio.Lock()
        self._flush_task: Optional[asyncio.Task] = None

    async def start(self):
        """Start the periodic background flush; a no-op if it is already running."""
        if self._flush_task is None or self._flush_task.done():
            self._flush_task = asyncio.create_task(self._flush_loop())

    async def stop(self):
        """Stop the background flush and write out everything still pending.

        Pending writes are flushed even if ``start()`` was never called.
        """
        task, self._flush_task = self._flush_task, None
        if task is not None:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        await self._flush_pending()

    async def write(self, key: str, value: Any):
        """Write ``value`` to the cache now and queue it for the database."""
        await self.cache.set(key, json.dumps(value))

        async with self._lock:
            self.pending_writes[key] = PendingWrite(
                key=key,
                value=value,
                timestamp=time.time()
            )
            should_flush = len(self.pending_writes) >= self.batch_size

        # Flush outside the lock: _flush_pending acquires it itself and
        # asyncio.Lock is not re-entrant.
        if should_flush:
            await self._flush_pending()

    async def _flush_loop(self):
        """Flush every ``flush_interval`` seconds until cancelled."""
        import logging

        log = logging.getLogger(__name__)
        while True:
            await asyncio.sleep(self.flush_interval)
            try:
                await self._flush_pending()
            except Exception:
                # The failed batch was re-queued by _flush_pending; keep the
                # loop alive so it is retried on the next tick.
                log.exception("write-behind flush failed; will retry")

    async def _flush_pending(self):
        """Send all pending writes to the database in one ``batch_save`` call.

        Raises whatever ``batch_save`` raises, after re-queueing the batch.
        """
        async with self._lock:
            if not self.pending_writes:
                return

            writes = list(self.pending_writes.values())
            self.pending_writes.clear()

        try:
            await self.database.batch_save(
                [(w.key, w.value) for w in writes]
            )
        except BaseException:
            # Also covers cancellation mid-save. There is no await in this
            # loop, so it cannot interleave with other coroutines. A newer
            # write queued during the failed save takes precedence.
            for w in writes:
                self.pending_writes.setdefault(w.key, w)
            raise

直读缓存

class ReadThrough:
    """Read-through cache: a key-based loader fills misses automatically."""

    def __init__(self, cache_client, loader, default_ttl: int = 3600):
        self.cache = cache_client
        self.loader = loader  # awaited as loader(key) on a miss
        self.default_ttl = default_ttl

    async def get(self, key: str) -> Optional[Any]:
        """Return the value for ``key``, loading and caching it on a miss.

        A ``None`` result from the loader is returned but not cached.
        """
        raw = await self.cache.get(key)
        if not raw:
            loaded = await self.loader(key)
            if loaded is None:
                return loaded
            await self.cache.setex(key, self.default_ttl, json.dumps(loaded))
            return loaded
        return json.loads(raw)

2. 失效策略

from enum import Enum
from typing import Set, List
import fnmatch

class InvalidationStrategy(Enum):
    """The cache-invalidation approaches covered in this section."""

    TIME_BASED = "time_based"
    EVENT_BASED = "event_based"
    VERSION_BASED = "version_based"

class CacheInvalidator:
    """Tag-, pattern- and version-based invalidation over a Redis-style async client.

    Values are stored as JSON. Tag membership lives in Redis sets named
    ``tag:<tag>``, so any process sharing the cache can invalidate by tag.
    """

    def __init__(self, cache_client):
        self.cache = cache_client
        # In-process mirror of tag -> keys tagged through this instance. The
        # Redis sets are the source of truth; a tag's entry is dropped when the
        # tag is invalidated so the mirror does not grow without bound.
        self._tag_index: Dict[str, Set[str]] = defaultdict(set)

    # --- Tag-based invalidation ---------------------------------------------
    async def set_with_tags(
        self,
        key: str,
        value: Any,
        tags: List[str],
        ttl: int = 3600
    ):
        """Store ``value`` under ``key`` for ``ttl`` seconds and register it under each tag."""
        await self.cache.setex(key, ttl, json.dumps(value))

        for tag in tags:
            self._tag_index[tag].add(key)
            await self.cache.sadd(f"tag:{tag}", key)

    async def invalidate_by_tag(self, tag: str):
        """Delete every key registered under ``tag``, and the tag set itself."""
        tag_key = f"tag:{tag}"
        keys = await self.cache.smembers(tag_key)
        # One DEL covers the members and the set; deleting a missing key is a no-op.
        await self.cache.delete(*(keys or ()), tag_key)
        self._tag_index.pop(tag, None)

    # --- Pattern-based invalidation -----------------------------------------
    async def invalidate_by_pattern(self, pattern: str):
        """Delete all keys matching the glob ``pattern``.

        Walks the keyspace with SCAN (count hint 100) until the cursor returns
        to 0, deleting each batch of matches as it goes.
        """
        cursor = 0
        while True:
            cursor, keys = await self.cache.scan(
                cursor,
                match=pattern,
                count=100
            )
            if keys:
                await self.cache.delete(*keys)
            if cursor == 0:
                break

    # --- Version-based invalidation -----------------------------------------
    @staticmethod
    def _versioned_key(key: str, version: int) -> str:
        """Return the physical key for ``key`` at ``version``."""
        return f"{key}:v{version}"

    async def get_versioned(self, key: str, version: int) -> Optional[Any]:
        """Return the value stored by ``set_versioned`` for ``version``, or None on a miss.

        The payload is JSON-decoded, mirroring the encoding in ``set_versioned``.
        Readers that move to a new version simply miss the old entries, which
        then age out via their TTL.
        """
        cached = await self.cache.get(self._versioned_key(key, version))
        return None if cached is None else json.loads(cached)

    async def set_versioned(
        self,
        key: str,
        value: Any,
        version: int,
        ttl: int = 3600
    ):
        """Store ``value`` as JSON under the ``version``-specific key for ``ttl`` seconds."""
        await self.cache.setex(self._versioned_key(key, version), ttl, json.dumps(value))

# Event-based invalidation using publish/subscribe
class EventBasedInvalidator:
    """Deletes cache entries in response to messages on the ``cache:invalidate`` channel.

    Each message is a JSON object carrying one of ``key``, ``pattern`` or
    ``tag``. Pattern and tag handling is delegated to ``CacheInvalidator``.
    """

    def __init__(self, cache_client, pubsub_client):
        self.cache = cache_client
        self.pubsub = pubsub_client
        self._invalidator = CacheInvalidator(cache_client)

    async def start_listener(self):
        """Subscribe and process invalidation messages until ``listen()`` ends.

        Messages that are not JSON objects are logged and skipped, so one bad
        publisher cannot kill the listener.
        """
        import logging

        log = logging.getLogger(__name__)
        await self.pubsub.subscribe("cache:invalidate")

        async for message in self.pubsub.listen():
            if message["type"] != "message":
                continue  # e.g. subscribe confirmations
            try:
                data = json.loads(message["data"])
            except (TypeError, ValueError):
                log.warning("ignoring malformed invalidation message: %r", message["data"])
                continue
            if not isinstance(data, dict):
                log.warning("ignoring non-object invalidation message: %r", data)
                continue
            await self._handle_invalidation(data)

    async def _handle_invalidation(self, data: dict):
        """Apply one decoded message; ``key`` wins over ``pattern``, which wins over ``tag``."""
        if "key" in data:
            await self.cache.delete(data["key"])
        elif "pattern" in data:
            await self._invalidator.invalidate_by_pattern(data["pattern"])
        elif "tag" in data:
            await self._invalidator.invalidate_by_tag(data["tag"])

    async def publish_invalidation(self, **kwargs):
        """Broadcast an invalidation, e.g. ``publish_invalidation(tag="users")``."""
        await self.pubsub.publish(
            "cache:invalidate",
            json.dumps(kwargs)
        )

3. TTL和过期

import random
from datetime import datetime, timedelta

class TTLManager:
    """Helpers for choosing cache TTLs, in seconds."""

    def __init__(self, base_ttl: int = 3600):
        self.base_ttl = base_ttl

    def get_ttl_with_jitter(self, ttl: Optional[int] = None) -> int:
        """Return ``ttl`` (or ``base_ttl`` when falsy) randomized by +/-10%.

        Spreading expirations keeps keys written together from expiring
        together. The result is clamped to at least 1 because a zero or
        negative expiry is rejected by ``SETEX``.
        """
        base = ttl or self.base_ttl
        jitter = random.uniform(-0.1, 0.1) * base
        return max(1, int(base + jitter))

    def get_sliding_ttl(
        self,
        last_access: datetime,
        min_ttl: int = 60,
        max_ttl: int = 3600
    ) -> int:
        """Pick a TTL from how recently the entry was last accessed.

        Accessed less than 1 minute ago -> ``max_ttl``; less than 5 minutes ago
        -> half of ``max_ttl`` (never below ``min_ttl``); otherwise ``min_ttl``.

        ``last_access`` may be naive (interpreted as UTC) or timezone-aware.
        """
        from datetime import timezone

        now = datetime.now(timezone.utc)
        if last_access.tzinfo is None:
            # Compare naive with naive; mixing naive and aware raises TypeError.
            now = now.replace(tzinfo=None)
        age = (now - last_access).total_seconds()

        if age < 60:
            return max_ttl  # recently accessed
        if age < 300:
            return max(min_ttl, max_ttl // 2)
        return min_ttl

    def get_tiered_ttl(self, data_type: str) -> int:
        """Return the TTL for a data category, or ``base_ttl`` for unknown ones."""
        ttl_tiers = {
            "user_session": 86400,      # 1 day
            "user_profile": 3600,       # 1 hour
            "product_catalog": 300,     # 5 minutes
            "search_results": 60,       # 1 minute
            "real_time_data": 10,       # 10 seconds
        }
        return ttl_tiers.get(data_type, self.base_ttl)

# Sliding-window expiry: every successful read pushes the expiry back.
class SlidingCache:
    """Cache whose entries stay alive for as long as they keep being read."""

    def __init__(self, cache_client, default_ttl: int = 3600):
        self.cache = cache_client
        self.default_ttl = default_ttl

    async def get(self, key: str) -> Optional[Any]:
        """Return the decoded value and reset its TTL to ``default_ttl``; None on a miss."""
        raw = await self.cache.get(key)
        if not raw:
            return None
        await self.cache.expire(key, self.default_ttl)
        return json.loads(raw)

    async def set(self, key: str, value: Any, ttl: Optional[int] = None):
        """Store ``value`` as JSON with ``ttl`` seconds (falsy -> ``default_ttl``)."""
        effective_ttl = ttl if ttl else self.default_ttl
        await self.cache.setex(key, effective_ttl, json.dumps(value))

4. Redis/Memcached模式

import redis.asyncio as redis
from typing import List, Tuple

class RedisCache:
    """Wrapper around a ``redis.asyncio`` client showing common Redis data-structure patterns."""

    def __init__(self, redis_url: str):
        self.redis = redis.from_url(redis_url)

    # Hashes, for structured data
    async def set_hash(self, key: str, data: dict, ttl: int = 3600):
        """Merge ``data`` into the hash at ``key`` and (re)set its TTL.

        HSET and EXPIRE are sent together in one MULTI/EXEC pipeline, so a
        failure between them cannot leave the hash without a TTL. An empty
        ``data`` is a no-op, because HSET requires at least one field.
        """
        if not data:
            return
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.hset(key, mapping=data)
            pipe.expire(key, ttl)
            await pipe.execute()

    async def get_hash(self, key: str) -> Optional[dict]:
        """Return the hash as a str->str dict, or None if it is missing or empty.

        Assumes the client returns bytes (i.e. it was not created with
        ``decode_responses=True``).
        """
        data = await self.redis.hgetall(key)
        return {k.decode(): v.decode() for k, v in data.items()} if data else None

    async def update_hash_field(self, key: str, field: str, value: str):
        """Set a single field of the hash at ``key``."""
        await self.redis.hset(key, field, value)

    # Sorted sets, for leaderboards / rankings
    async def add_to_ranking(self, key: str, member: str, score: float):
        """Add ``member`` with ``score``, or update its score."""
        await self.redis.zadd(key, {member: score})

    async def get_top_n(self, key: str, n: int) -> List[Tuple[str, float]]:
        """Return the ``n`` highest-scoring members with their scores, best first.

        ``n <= 0`` returns an empty list; passing it through would ask for
        ``ZREVRANGE key 0 -1``, i.e. the entire set.
        """
        if n <= 0:
            return []
        return await self.redis.zrevrange(key, 0, n - 1, withscores=True)

    # Lists, for queues
    async def push_to_queue(self, key: str, *values):
        """Push ``values`` onto the left end of the list."""
        await self.redis.lpush(key, *values)

    async def pop_from_queue(self, key: str, timeout: int = 0):
        """Pop from the right end (FIFO with ``push_to_queue``), blocking up to ``timeout`` via BRPOP."""
        return await self.redis.brpop(key, timeout=timeout)

    # Pub/sub, for cache invalidation
    async def publish(self, channel: str, message: str):
        """Publish ``message`` on ``channel``."""
        await self.redis.publish(channel, message)

    # Lua script, for atomic operations
    async def increment_with_cap(self, key: str, cap: int) -> int:
        """Atomically INCR ``key`` unless it already is >= ``cap``.

        Returns the new value, or -1 if the cap was already reached.
        """
        script = """
        local current = redis.call('GET', KEYS[1])
        if current and tonumber(current) >= tonumber(ARGV[1]) then
            return -1
        end
        return redis.call('INCR', KEYS[1])
        """
        return await self.redis.eval(script, 1, key, cap)

    # Pipelines, for batch operations
    async def batch_get(self, keys: List[str]) -> List[Optional[str]]:
        """GET many keys in one round trip; results are in ``keys`` order."""
        async with self.redis.pipeline() as pipe:
            for key in keys:
                pipe.get(key)
            return await pipe.execute()

    async def batch_set(
        self,
        items: List[Tuple[str, str]],
        ttl: int = 3600
    ):
        """SETEX many (key, value) pairs with the same ``ttl`` in one round trip."""
        async with self.redis.pipeline() as pipe:
            for key, value in items:
                pipe.setex(key, ttl, value)
            await pipe.execute()

5. 缓存击穿预防

import asyncio
import hashlib
import time
from typing import Optional, Callable

class StampedeProtection:
    """Cache stampede (dog-pile) protection using several strategies.

    All strategies store JSON in a Redis-style async client exposing ``get``,
    ``setex``, ``set(key, value, nx=..., ex=...)`` and ``delete``.
    """

    def __init__(self, cache_client):
        import weakref

        self.cache = cache_client
        # key -> asyncio.Lock, held weakly: a lock disappears once no coroutine
        # holds or awaits it, so the mapping does not grow with every key seen.
        self._locks = weakref.WeakValueDictionary()
        # The event loop keeps only weak references to tasks; hold strong ones
        # to background revalidations until they finish.
        self._background_tasks: Set[asyncio.Task] = set()

    # Strategy 1: locking (prevents concurrent regeneration)
    async def get_with_lock(
        self,
        key: str,
        loader: Callable,
        ttl: int = 3600
    ) -> Any:
        """Return the cached value; on a miss only one coroutine per key runs ``loader``.

        Other coroutines that miss wait on the per-key lock and then re-read
        the cache. The lock is in-process only: separate processes may still
        each regenerate the value once.
        """
        cached = await self.cache.get(key)
        if cached:
            return json.loads(cached)

        lock = self._locks.get(key)
        if lock is None:
            lock = asyncio.Lock()
            self._locks[key] = lock

        # The local `lock` reference keeps the weak entry alive while we wait/hold.
        async with lock:
            # Another coroutine may have filled the cache while we waited.
            cached = await self.cache.get(key)
            if cached:
                return json.loads(cached)

            value = await loader()
            await self.cache.setex(key, ttl, json.dumps(value))
            return value

    # Strategy 2: probabilistic early expiration (XFetch)
    async def get_with_early_recompute(
        self,
        key: str,
        loader: Callable,
        ttl: int = 3600,
        beta: float = 1.0
    ) -> Any:
        """Return the cached value, sometimes recomputing it before it expires.

        Each entry records its absolute ``_expiry`` and how long ``loader``
        took (``_delta``). A read recomputes when
        ``now - delta * beta * ln(u) >= expiry`` with ``u`` uniform in (0, 1].
        Because ``-ln(u) >= 0``, early recomputation becomes more likely as
        expiry approaches and for slower loaders; larger ``beta`` means earlier.
        """
        import math

        data = await self.cache.get(key)

        if data:
            cached = json.loads(data)
            expiry = cached.get("_expiry", 0)
            delta = cached.get("_delta", 0)

            # 1 - random() lies in (0, 1], so log() is defined and <= 0.
            head_start = -delta * beta * math.log(1.0 - random.random())
            if time.time() + head_start < expiry:
                return cached["value"]

        # Recompute, timing the loader for future early-expiry decisions.
        start = time.time()
        value = await loader()
        delta = time.time() - start

        cache_data = {
            "value": value,
            "_expiry": time.time() + ttl,
            "_delta": delta
        }
        await self.cache.setex(key, ttl, json.dumps(cache_data))
        return value

    # Strategy 3: stale-while-revalidate
    async def get_stale_while_revalidate(
        self,
        key: str,
        loader: Callable,
        ttl: int = 3600,
        stale_ttl: int = 300
    ) -> Any:
        """Return the cached value, refreshing it in the background once stale.

        An entry lives ``ttl`` seconds; during its last ``stale_ttl`` seconds it
        is stale: it is still returned immediately, but a background
        revalidation is started. A miss loads synchronously.
        """
        data = await self.cache.get(key)

        if data:
            cached = json.loads(data)

            if self._is_fresh(cached):
                return cached["value"]

            # Serve stale data, refresh in the background.
            self._spawn(self._revalidate(key, loader, ttl, stale_ttl))
            return cached["value"]

        return await self._load_and_cache(key, loader, ttl, stale_ttl)

    @staticmethod
    def _is_fresh(cached: dict) -> bool:
        """An entry is fresh until ``_stale_at``; an explicit ``_fresh: false`` also marks it stale."""
        if not cached.get("_fresh", True):
            return False
        return time.time() < cached.get("_stale_at", float("inf"))

    def _spawn(self, coro) -> None:
        """Run ``coro`` as a background task, keeping a reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _revalidate(
        self,
        key: str,
        loader: Callable,
        ttl: int,
        stale_ttl: int
    ):
        """Reload ``key`` unless another worker already holds its refresh lock."""
        lock_key = f"lock:{key}"
        # SET NX EX acquires the lock and sets its 30 s expiry atomically, so a
        # crash right after acquiring cannot leave a lock that never expires.
        acquired = await self.cache.set(lock_key, "1", nx=True, ex=30)
        if not acquired:
            return

        try:
            await self._load_and_cache(key, loader, ttl, stale_ttl)
        finally:
            await self.cache.delete(lock_key)

    async def _load_and_cache(
        self,
        key: str,
        loader: Callable,
        ttl: int,
        stale_ttl: int
    ):
        """Run ``loader`` and store its value with a ``_stale_at`` timestamp."""
        value = await loader()

        cache_data = {
            "value": value,
            "_fresh": True,
            # Staleness is stored with the entry instead of being applied by a
            # sleeping timer task, so it survives restarts and costs no task.
            "_stale_at": time.time() + ttl - stale_ttl,
        }
        await self.cache.setex(key, ttl, json.dumps(cache_data))
        return value

6. 分布式缓存

import hashlib
from typing import List, Optional
import random

class ConsistentHashing:
    """Consistent-hash ring for spreading keys across cache nodes.

    Each node is placed on the ring ``replicas`` times (virtual nodes). A key
    maps to the first virtual node whose hash is >= the key's hash, wrapping
    around to the smallest one.
    """

    def __init__(self, nodes: List[str], replicas: int = 100):
        self.replicas = replicas
        self.ring: Dict[int, str] = {}
        # Always equal to sorted(self.ring); kept for O(log n) lookups.
        self.sorted_keys: List[int] = []

        for node in nodes:
            self.add_node(node)

    def _hash(self, key: str) -> int:
        # MD5 is used only for even distribution, not for security.
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def add_node(self, node: str):
        """Place ``node``'s virtual points on the ring. Re-adding a node is a no-op."""
        for i in range(self.replicas):
            point = self._hash(f"{node}:{i}")
            if point not in self.ring:
                self.ring[point] = node
        # Rebuilding from the dict keeps sorted_keys duplicate-free.
        self.sorted_keys = sorted(self.ring)

    def remove_node(self, node: str):
        """Remove all of ``node``'s virtual points.

        Raises:
            KeyError: if ``node`` is not on the ring.
        """
        removed = False
        for i in range(self.replicas):
            point = self._hash(f"{node}:{i}")
            if self.ring.get(point) == node:
                del self.ring[point]
                removed = True
        if not removed:
            raise KeyError(node)
        self.sorted_keys = sorted(self.ring)

    def get_node(self, key: str) -> str:
        """Return the node responsible for ``key``.

        Raises:
            ValueError: if the ring has no nodes.
        """
        from bisect import bisect_left

        if not self.ring:
            raise ValueError("环中没有节点")

        h = self._hash(key)
        idx = bisect_left(self.sorted_keys, h)
        if idx == len(self.sorted_keys):
            idx = 0  # past the largest point: wrap around the ring
        return self.ring[self.sorted_keys[idx]]

class DistributedCache:
    """Redis cache sharded across several nodes with consistent hashing."""

    def __init__(self, nodes: List[str]):
        self.ring = ConsistentHashing(nodes)
        self.clients: Dict[str, RedisCache] = {}
        for node in nodes:
            self.clients[node] = RedisCache(node)

    def _get_client(self, key: str) -> RedisCache:
        """Return the client of the node that owns ``key``."""
        return self.clients[self.ring.get_node(key)]

    async def get(self, key: str) -> Optional[str]:
        """GET ``key`` from its owning node."""
        return await self._get_client(key).redis.get(key)

    async def set(self, key: str, value: str, ttl: int = 3600):
        """SETEX ``key`` on its owning node."""
        await self._get_client(key).redis.setex(key, ttl, value)

    async def delete(self, key: str):
        """DEL ``key`` on its owning node."""
        await self._get_client(key).redis.delete(key)

    # Multi-key get across nodes
    async def mget(self, keys: List[str]) -> Dict[str, Optional[str]]:
        """Fetch many keys, one pipelined request per owning node, run concurrently."""
        grouped: Dict[str, List[str]] = defaultdict(list)
        for key in keys:
            grouped[self.ring.get_node(key)].append(key)

        results = {}
        await asyncio.gather(*[
            self._fetch_from_node(self.clients[node], node_keys, results)
            for node, node_keys in grouped.items()
        ])
        return results

    async def _fetch_from_node(
        self,
        client: RedisCache,
        keys: List[str],
        results: Dict
    ):
        """Batch-GET ``keys`` from one node and record them into ``results``."""
        values = await client.batch_get(keys)
        results.update(zip(keys, values))

7. 数据库查询缓存

import sqlalchemy
from typing import Optional, List, Any
import hashlib

class QueryCache:
    """Caches database query results, with per-table invalidation.

    Results are stored as JSON lists of row dicts under ``query:`` keys. For
    ``invalidate_table`` to find a query's entries, pass the tables the query
    reads via ``tables=``; their names are embedded in the cache key.
    """

    def __init__(self, cache_client, default_ttl: int = 300):
        """
        Args:
            cache_client: async client exposing ``get``, ``setex`` and ``delete_pattern``.
            default_ttl: TTL in seconds for cached results.
        """
        self.cache = cache_client
        self.default_ttl = default_ttl

    def _query_key(
        self,
        sql: str,
        params: tuple,
        tables: Optional[List[str]] = None
    ) -> str:
        """Build the cache key from the SQL text, bound parameters and tables.

        Without ``tables`` the key is ``query:<md5>``. With tables it is
        ``query:<t1>,<t2>:<md5>`` (sorted, de-duplicated), which is what lets
        the ``query:*<table>*`` pattern used by ``invalidate_table`` match.
        """
        digest = hashlib.md5(f"{sql}:{params}".encode()).hexdigest()
        if not tables:
            return f"query:{digest}"
        return f"query:{','.join(sorted(set(tables)))}:{digest}"

    async def execute_cached(
        self,
        session,
        query,
        params: Optional[dict] = None,
        ttl: Optional[int] = None,
        tables: Optional[List[str]] = None
    ) -> List[Any]:
        """Run ``query`` through ``session`` unless a cached result exists.

        Args:
            session: object whose ``execute(query, params).fetchall()`` returns rows.
            query: statement; ``str(query)`` is part of the cache key.
            params: bound parameters; part of the cache key.
            ttl: TTL in seconds; falsy uses ``default_ttl``.
            tables: tables the query reads, so ``invalidate_table`` can find it.

        Returns:
            List of rows as dicts (JSON round-tripped on cache hits).
        """
        sql_str = str(query)
        cache_key = self._query_key(
            sql_str, tuple(sorted((params or {}).items())), tables
        )

        cached = await self.cache.get(cache_key)
        if cached is not None:
            return json.loads(cached)

        result = session.execute(query, params).fetchall()
        # SQLAlchemy 1.4+/2.0 rows are tuple-like and expose their mapping via
        # `_mapping`; older RowProxy objects (or plain dicts) convert directly.
        serialized = [dict(getattr(row, "_mapping", row)) for row in result]

        await self.cache.setex(
            cache_key,
            ttl or self.default_ttl,
            json.dumps(serialized)
        )

        return serialized

    async def invalidate_table(self, table_name: str):
        """Delete cached queries whose key mentions ``table_name``.

        Only entries cached with ``tables=[..., table_name, ...]`` carry the
        name in their key. Matching is by substring, so entries of tables whose
        names contain ``table_name`` may be dropped too (extra misses only).
        """
        await self.cache.delete_pattern(f"query:*{table_name}*")

# 使用SQLAlchemy的ORM级缓存
from sqlalchemy import event
from sqlalchemy.orm import Session

class ORMCache:
    """Per-object cache for SQLAlchemy models plus flush-driven query invalidation."""

    def __init__(self, cache_client):
        self.cache = cache_client
        # Strong references to scheduled invalidation tasks until they finish;
        # the event loop itself only keeps weak references to tasks.
        self._pending_tasks: Set[asyncio.Task] = set()

    def setup_listeners(self, engine):
        """Register an ``after_flush`` hook that invalidates cached queries.

        For every table with new, dirty or deleted objects in a flush, one
        ``delete_pattern("query:*<table>*")`` is scheduled on the event loop.

        NOTE(review):
        - The hook is registered on the ``Session`` class, so it fires for
          every session; ``engine`` is unused.
        - Each call registers another hook.
        - ``asyncio.create_task`` raises RuntimeError if no event loop is
          running in the flushing thread.
        - Keys written by ``get_by_id`` (``<table>:<id>``) are not matched by
          this pattern.
        """

        @event.listens_for(Session, "after_flush")
        def receive_after_flush(session, flush_context):
            # One invalidation per table, not one per modified object.
            tables = {
                obj.__tablename__
                for obj in session.dirty | session.new | session.deleted
            }
            for table in tables:
                task = asyncio.create_task(
                    self.cache.delete_pattern(f"query:*{table}*")
                )
                self._pending_tasks.add(task)
                task.add_done_callback(self._pending_tasks.discard)

    async def get_by_id(self, model, obj_id: int, session):
        """Return the row with primary key ``obj_id``, caching ``to_dict()`` for an hour.

        NOTE(review): a cache hit returns the cached dict, while a miss returns
        the ORM instance itself (or None); callers must accept both.
        """
        key = f"{model.__tablename__}:{obj_id}"
        cached = await self.cache.get(key)

        if cached is not None:
            return json.loads(cached)

        obj = session.query(model).get(obj_id)
        if obj is not None:
            await self.cache.setex(key, 3600, json.dumps(obj.to_dict()))

        return obj

8. CDN和静态资源缓存

from datetime import datetime, timedelta
from typing import Dict, Optional

class CDNCache:
    """HTTP caching headers for CDN and browser caching strategies."""

    @staticmethod
    def get_cache_headers(
        cache_type: str,
        max_age: int = 3600
    ) -> Dict[str, str]:
        """Return the HTTP caching headers for one strategy.

        Strategies: ``immutable``, ``versioned``, ``revalidate``, ``private``
        and ``no-cache``; unknown names fall back to ``revalidate``.

        ``versioned`` and ``revalidate`` contain an ``"ETag": None``
        placeholder that the caller must fill in (or drop) before sending.
        """
        import time
        from email.utils import formatdate

        strategies = {
            "immutable": {
                "Cache-Control": f"public, max-age={max_age}, immutable",
                # An HTTP-date needs English day/month names and "GMT".
                # strftime's %a/%b follow the process locale; formatdate does not.
                "Expires": formatdate(time.time() + max_age, usegmt=True),
            },
            "versioned": {
                "Cache-Control": f"public, max-age={max_age}",
                "ETag": None,  # filled in by the caller
            },
            "revalidate": {
                "Cache-Control": "public, max-age=0, must-revalidate",
                "ETag": None,  # filled in by the caller
            },
            "private": {
                "Cache-Control": "private, max-age=0, must-revalidate",
                "Pragma": "no-cache",
            },
            "no-cache": {
                "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
                "Pragma": "no-cache",
                "Expires": "0",
            },
        }

        return strategies.get(cache_type, strategies["revalidate"])

    @staticmethod
    def generate_etag(content: bytes) -> str:
        """Return the hex MD5 of ``content`` for use as an (unquoted) ETag value."""
        return hashlib.md5(content).hexdigest()

# FastAPI/Flask示例
from fastapi import Response
from fastapi.responses import FileResponse

class StaticAssetCache:
    """Serves static files with caching headers and builds cache-busting CDN URLs."""

    def __init__(self, cdn_base_url: Optional[str] = None):
        self.cdn_base_url = cdn_base_url

    def serve_asset(
        self,
        file_path: str,
        asset_type: str = "immutable"
    ) -> FileResponse:
        """Return a ``FileResponse`` for ``file_path`` with caching headers.

        ``immutable`` assets get a max-age of 31536000 s (one year); all other
        types get 3600 s. Whenever the chosen strategy carries an ``ETag``
        (including the ``revalidate`` fallback used for unknown ``asset_type``
        values), it is set to the quoted MD5 of the file contents.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
        """
        import os

        headers = CDNCache.get_cache_headers(
            asset_type,
            max_age=31536000 if asset_type == "immutable" else 3600
        )

        if "ETag" in headers:
            # Only ETag strategies need the file's bytes; FileResponse streams
            # the file itself, so other types skip reading it into memory.
            with open(file_path, "rb") as f:
                headers["ETag"] = f'"{CDNCache.generate_etag(f.read())}"'
        else:
            # Fail fast on a missing file, as the open() above would.
            os.stat(file_path)

        return FileResponse(
            file_path,
            headers=headers,
            media_type=self._get_media_type(file_path)
        )

    def get_asset_url(self, path: str, version: Optional[str] = None) -> str:
        """Build the public URL for ``path``, optionally cache-busted with ``version``.

        ``"app.js"`` with version ``"abc"`` becomes ``<base>/app.abc.js``; a
        path without an extension gets ``.<version>`` appended. Slashes at the
        join are normalized so the result never contains ``//`` there.
        """
        import posixpath

        base = (self.cdn_base_url or "").rstrip("/")

        if version:
            # Insert the version before the extension for cache busting.
            stem, ext = posixpath.splitext(path)
            path = f"{stem}.{version}{ext}"

        return f"{base}/{path.lstrip('/')}"

    @staticmethod
    def _get_media_type(file_path: str) -> str:
        """Return the MIME type for ``file_path`` (extension match is case-insensitive).

        Falls back to ``mimetypes.guess_type`` and finally to
        ``application/octet-stream``.
        """
        import mimetypes
        import os

        ext_to_mime = {
            ".js": "application/javascript",
            ".css": "text/css",
            ".png": "image/png",
            ".jpg": "image/jpeg",
            ".svg": "image/svg+xml",
            ".woff2": "font/woff2",
        }
        ext = os.path.splitext(file_path)[1].lower()
        if ext in ext_to_mime:
            return ext_to_mime[ext]
        guessed, _encoding = mimetypes.guess_type(file_path)
        return guessed or "application/octet-stream"

9. 机器学习模型缓存

import pickle
from typing import Any, Callable
import numpy as np

class MLModelCache:
    """Caches pickled ML models and JSON-serializable prediction results."""

    def __init__(self, cache_client, model_store_path: str = "./models"):
        """
        Args:
            cache_client: async client exposing ``get`` and ``setex``.
            model_store_path: root directory laid out as ``<id>/<version>/model.pkl``.
        """
        self.cache = cache_client
        self.model_store = model_store_path

    @staticmethod
    def hash_input(input_data: Any) -> str:
        """Return the key fragment used for ``input_data`` in prediction keys.

        MD5 of the JSON encoding with sorted keys, so dicts with the same
        content hash identically. Pass the result to ``cache_prediction`` so
        that ``get_cached_prediction`` finds the entry.
        """
        return hashlib.md5(
            json.dumps(input_data, sort_keys=True).encode()
        ).hexdigest()

    async def get_model(self, model_id: str, version: str):
        """Load a model from the cache, else from disk, caching it for one day.

        SECURITY: ``pickle.loads``/``pickle.load`` execute code embedded in the
        payload. Only use this with a cache and model store that untrusted
        parties cannot write to.
        """
        key = f"model:{model_id}:v{version}"

        cached = await self.cache.get(key)
        if cached:
            return pickle.loads(cached)

        path = f"{self.model_store}/{model_id}/{version}/model.pkl"
        with open(path, "rb") as f:
            model = pickle.load(f)

        # NOTE(review): stored uncompressed; consider compressing large models.
        await self.cache.setex(
            key,
            86400,  # 1 day
            pickle.dumps(model)
        )

        return model

    async def cache_prediction(
        self,
        model_id: str,
        input_hash: str,
        prediction: Any,
        ttl: int = 3600
    ):
        """Cache ``prediction`` (JSON) under ``pred:<model_id>:<input_hash>``.

        ``input_hash`` should come from ``hash_input(input_data)``.
        """
        key = f"pred:{model_id}:{input_hash}"
        await self.cache.setex(key, ttl, json.dumps(prediction))

    async def get_cached_prediction(
        self,
        model_id: str,
        input_data: Any
    ) -> Optional[Any]:
        """Return the cached prediction for ``input_data``, or None."""
        key = f"pred:{model_id}:{self.hash_input(input_data)}"
        cached = await self.cache.get(key)

        return json.loads(cached) if cached else None

# Feature cache for ML pipelines
class FeatureCache:
    """Per-entity feature cache: one JSON value per (entity, feature) key."""

    def __init__(self, cache_client):
        self.cache = cache_client

    @staticmethod
    def _key(entity_id: str, name: str) -> str:
        """Return the cache key for one feature of one entity."""
        return f"feature:{entity_id}:{name}"

    async def get_features(
        self,
        entity_id: str,
        feature_names: List[str],
        compute_fn: Callable,
        ttl: int = 3600
    ) -> Dict[str, Any]:
        """Return the requested features, computing and caching only the missing ones.

        Args:
            entity_id: entity whose features are requested.
            feature_names: feature names to fetch.
            compute_fn: async callable ``(entity_id, missing_names) -> {name: value}``.
            ttl: TTL in seconds for newly computed features.

        Returns:
            name -> value for every cache hit plus everything ``compute_fn`` returned.
        """
        if not feature_names:
            # Nothing to fetch; also avoids an MGET with no keys, which Redis rejects.
            return {}

        keys = [self._key(entity_id, name) for name in feature_names]
        cached_values = await self.cache.mget(keys)

        features = {}
        missing = []

        for name, raw in zip(feature_names, cached_values):
            if raw is not None:
                features[name] = json.loads(raw)
            else:
                missing.append(name)

        if missing:
            computed = await compute_fn(entity_id, missing)

            for name, value in computed.items():
                await self.cache.setex(self._key(entity_id, name), ttl, json.dumps(value))
                features[name] = value

        return features

# Embedding cache for vector similarity
class EmbeddingCache:
    """Caches text embeddings as raw float32 bytes for one week."""

    def __init__(self, cache_client):
        self.cache = cache_client

    async def get_embedding(
        self,
        text: str,
        model: str,
        embed_fn: Callable
    ) -> np.ndarray:
        """Return the embedding of ``text`` under ``model``, computing it on a miss.

        Args:
            text: input text; its MD5 is part of the cache key.
            model: model name; part of the cache key.
            embed_fn: async callable ``text -> array-like`` of numbers.

        Returns:
            A writable float32 array. Cache hits come back 1-D, because only the
            raw bytes are stored. NOTE(review): multi-dimensional embeddings
            lose their shape on a hit.
        """
        text_hash = hashlib.md5(text.encode()).hexdigest()
        key = f"embed:{model}:{text_hash}"

        cached = await self.cache.get(key)
        if cached:
            # frombuffer views the immutable bytes (read-only); copy so that
            # hits and misses both return writable arrays.
            return np.frombuffer(cached, dtype=np.float32).copy()

        # Normalize to float32: the bytes are read back as float32 above, so
        # storing any other dtype (e.g. float64) would come back garbled.
        embedding = np.asarray(await embed_fn(text), dtype=np.float32)

        await self.cache.setex(
            key,
            86400 * 7,  # 1 week
            embedding.tobytes()
        )

        return embedding

10. 本地内存缓存

from collections import OrderedDict
from threading import RLock
from typing import Optional
import time

class LRUCache:
    """Thread-safe LRU cache with optional per-entry TTL (seconds)."""

    def __init__(self, capacity: int = 1000, default_ttl: Optional[int] = None):
        # key -> (value, expiry); expiry is a time.monotonic() deadline or None.
        self.cache = OrderedDict()
        self.capacity = capacity
        self.default_ttl = default_ttl
        self.lock = RLock()

    def get(self, key: str, default: Any = None) -> Optional[Any]:
        """Return the live value for ``key`` and mark it most recently used.

        Returns ``default`` on a miss; an expired entry is removed and counts
        as a miss.
        """
        with self.lock:
            entry = self.cache.get(key)
            if entry is None:
                return default

            value, expiry = entry

            # Monotonic clock: unaffected by wall-clock adjustments.
            if expiry is not None and time.monotonic() > expiry:
                del self.cache[key]
                return default

            self.cache.move_to_end(key)  # most recently used
            return value

    def put(self, key: str, value: Any, ttl: Optional[int] = None):
        """Store ``value``; a falsy ``ttl`` uses ``default_ttl`` (None -> no expiry).

        Evicts the least recently used entry when over ``capacity``.
        """
        with self.lock:
            if key in self.cache:
                self.cache.move_to_end(key)

            ttl = ttl or self.default_ttl
            expiry = time.monotonic() + ttl if ttl else None
            self.cache[key] = (value, expiry)

            if len(self.cache) > self.capacity:
                self.cache.popitem(last=False)  # least recently used

    def invalidate(self, key: str):
        """Remove ``key`` if present."""
        with self.lock:
            self.cache.pop(key, None)

    def clear(self):
        """Remove all entries."""
        with self.lock:
            self.cache.clear()

# Function memoization decorator
from functools import wraps

def memoize(ttl: Optional[int] = None, maxsize: int = 128):
    """Memoize results in an ``LRUCache`` of ``maxsize`` entries with optional TTL.

    The key combines the function name with the ``repr`` of the arguments, so
    unhashable arguments work, but arguments with identical ``repr`` share an
    entry. ``None`` results are cached like any other value. The wrapper
    exposes ``cache_clear()`` and ``cache_info()``.
    """

    cache = LRUCache(capacity=maxsize, default_ttl=ttl)
    # Sentinel so that a cached None is distinguishable from a miss.
    missing = object()

    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            key = f"{func.__name__}:{args}:{sorted(kwargs.items())}"

            result = cache.get(key, missing)
            if result is not missing:
                return result

            result = func(*args, **kwargs)
            cache.put(key, result)
            return result

        wrapper.cache_clear = cache.clear
        wrapper.cache_info = lambda: {
            "size": len(cache.cache),
            "maxsize": cache.capacity
        }

        return wrapper

    return decorator

# Usage
@memoize(ttl=300, maxsize=1000)
def expensive_computation(x: int, y: int) -> int:
    return x ** y

最佳实践

  1. 缓存重要内容:专注于频繁访问、计算成本高的数据。

  2. 设置适当TTL:平衡新鲜度和性能。使用抖动防止惊群效应。

  3. 优雅处理缓存故障:缓存应增强而非必需操作。

  4. 监控缓存性能:跟踪命中率、延迟和内存使用。

  5. 使用适当数据结构:对象用哈希,排名用有序集合,队列用列表。

  6. 实现正确失效:对于关键数据,优先事件驱动失效而非仅TTL。

  7. 预防击穿:使用锁定、早期重新计算或陈旧数据更新。

  8. 适当调整缓存大小:监控淘汰率并相应调整大小。

  9. 分层缓存:请求依次经过浏览器缓存 → CDN → Redis,全部未命中时才访问数据库,以获得最佳性能。

  10. 安全考虑:未经加密绝不缓存敏感数据(密码、令牌、PII)。

示例

完整缓存层

class CacheLayer:
    """Cache layer combining stampede protection, tiered TTLs and invalidation."""

    def __init__(self, redis_url: str, default_ttl: int = 3600, database=None):
        """
        Args:
            redis_url: Redis connection URL.
            default_ttl: base TTL in seconds for ``TTLManager``.
            database: async data source used on misses and for writes; it is
                expected to provide ``get_user``, ``update_user`` and
                ``get_products`` (as called below).
        """
        self.redis = RedisCache(redis_url)
        self.stampede = StampedeProtection(self.redis.redis)
        self.ttl_manager = TTLManager(default_ttl)
        self.invalidator = CacheInvalidator(self.redis.redis)
        self.database = database

    @staticmethod
    def _user_key(user_id: str) -> str:
        """Cache key under which ``get_user`` stores a user."""
        return f"user:{user_id}"

    async def get_user(self, user_id: str) -> "User":
        """Return a user, with one loader per key and a jittered profile TTL."""
        ttl = self.ttl_manager.get_ttl_with_jitter(
            self.ttl_manager.get_tiered_ttl("user_profile")
        )
        return await self.stampede.get_with_lock(
            self._user_key(user_id),
            lambda: self.database.get_user(user_id),
            ttl=ttl
        )

    async def update_user(self, user_id: str, data: dict):
        """Update the user in the database, then drop every cached copy."""
        await self.database.update_user(user_id, data)
        # get_user caches under the plain key without registering any tag, so
        # the key must be deleted directly; the tag pass still clears entries
        # that were stored with the ``user:<id>`` tag via set_with_tags.
        await self.redis.redis.delete(self._user_key(user_id))
        await self.invalidator.invalidate_by_tag(f"user:{user_id}")

    async def get_product_listing(self, category: str) -> List["Product"]:
        """Return a category listing, serving stale data while it refreshes."""
        return await self.stampede.get_stale_while_revalidate(
            f"products:{category}",
            lambda: self.database.get_products(category),
            ttl=self.ttl_manager.get_tiered_ttl("product_catalog"),
            stale_ttl=60
        )