name: 并发 description: 全面的并发和并行模式,用于多线程和异步编程。在实现异步/等待、并行处理、线程安全、工作池或调试竞态条件和死锁时使用。触发词:异步、等待、线程、互斥锁、锁、信号量、通道、Actor、并行、并发、竞态条件、死锁、活锁、原子操作、Future、Promise、Tokio、Asyncio、Goroutine、Spawn、Arc、Mutex、RwLock、MPSC、Select、Join、工作池、队列、同步、临界区、上下文切换。
并发
概述
并发使程序能够高效处理多个任务。此技能涵盖 Rust(Tokio)、Python(Asyncio)、TypeScript(Promise)和 Go(Goroutine)的异步/等待模式。包括并行策略、竞态条件预防、死锁处理、线程安全模式、基于通道的通信和工作队列实现。
代理专家专长
在实现并发时,委托给适当的专家:
- senior-software-engineer(Opus) - 并发系统的架构决策,选择线程模型,设计消息传递与共享状态架构
- software-engineer(Sonnet) - 遵循既定模式实现并发代码,编写异步函数、工作池、速率限制器
- senior-software-engineer(Opus) - 识别竞态条件、检查时间到使用时间(TOCTOU)漏洞,审查死锁的锁获取顺序
- senior-software-engineer(Opus) - 分布式系统并发、一致性模型、分布式锁、Saga模式
指令
1. Rust Async/Await with Tokio
use tokio::sync::{Mutex, RwLock, Semaphore, mpsc};
use tokio::time::{sleep, Duration, timeout};
use std::sync::Arc;
use futures::future::join_all;
/// Fetch the body of `url` as text via `reqwest`.
async fn fetch_data(url: &str) -> Result<String, reqwest::Error> {
    let body = reqwest::get(url).await?.text().await?;
    Ok(body)
}
/// Fetch all URLs concurrently, one spawned task per URL.
///
/// Each element of the returned vector is that URL's fetch result.
async fn fetch_all(urls: Vec<String>) -> Vec<Result<String, reqwest::Error>> {
    let tasks: Vec<_> = urls
        .into_iter()
        .map(|url| tokio::spawn(async move { fetch_data(&url).await }))
        .collect();
    join_all(tasks)
        .await
        .into_iter()
        // A JoinError only occurs if a spawned task panicked or was
        // cancelled; surface that with a clear message instead of a bare
        // unwrap.
        .map(|joined| joined.expect("spawned fetch task panicked"))
        .collect()
}
/// Fetch `url`, failing if it takes longer than 5 seconds.
async fn fetch_with_timeout(url: &str) -> Result<String, Box<dyn std::error::Error>> {
    let deadline = Duration::from_secs(5);
    if let Ok(result) = timeout(deadline, fetch_data(url)).await {
        Ok(result?)
    } else {
        Err(format!("Request to {} timed out", url).into())
    }
}
// 使用信号量进行速率限制
async fn fetch_with_rate_limit(
urls: Vec<String>,
max_concurrent: usize,
) -> Vec<Result<String, reqwest::Error>> {
let semaphore = Arc::new(Semaphore::new(max_concurrent));
let tasks: Vec<_> = urls.into_iter()
.map(|url| {
let sem = semaphore.clone();
tokio::spawn(async move {
let _permit = sem.acquire().await.unwrap();
fetch_data(&url).await
})
})
.collect();
join_all(tasks).await
.into_iter()
.map(|r| r.unwrap())
.collect()
}
// 基于通道的工作模式
async fn worker_pool_example() {
let (tx, mut rx) = mpsc::channel::<String>(100);
// 生成工作者
for i in 0..4 {
let mut worker_rx = rx.clone();
tokio::spawn(async move {
while let Some(url) = worker_rx.recv().await {
println!("Worker {} processing {}", i, url);
let _ = fetch_data(&url).await;
}
});
}
// 发送工作
for url in vec!["https://example.com"; 20] {
tx.send(url.to_string()).await.unwrap();
}
}
/// Map shared between tasks behind `Arc<Mutex<_>>`; cloning shares the
/// same underlying map.
#[derive(Clone)]
struct SharedCache {
    data: Arc<Mutex<std::collections::HashMap<String, String>>>,
}
impl SharedCache {
    /// Return the cached value for `key`, inserting `value` first if the
    /// key is absent.
    async fn get_or_insert(&self, key: String, value: String) -> String {
        let mut guard = self.data.lock().await;
        let entry = guard.entry(key).or_insert(value);
        entry.clone()
    }
}
/// Read-heavy cache: `RwLock` lets many readers proceed in parallel while
/// writers get exclusive access.
struct ReadHeavyCache {
    data: Arc<RwLock<std::collections::HashMap<String, String>>>,
}
impl ReadHeavyCache {
    /// Shared read lock: concurrent `get`s do not block each other.
    async fn get(&self, key: &str) -> Option<String> {
        self.data.read().await.get(key).cloned()
    }
    /// Exclusive write lock for mutation.
    async fn insert(&self, key: String, value: String) {
        self.data.write().await.insert(key, value);
    }
}
// 使用 Select 竞赛多个 Future
use tokio::select;
async fn fetch_from_fastest(urls: Vec<String>) -> Option<String> {
let mut tasks = urls.into_iter()
.map(|url| Box::pin(fetch_data(&url)))
.collect::<Vec<_>>();
if tasks.is_empty() {
return None;
}
loop {
select! {
result = tasks[0], if !tasks.is_empty() => {
if result.is_ok() {
return result.ok();
}
tasks.remove(0);
}
else => break,
}
}
None
}
2. Python 异步模式
import asyncio
from typing import List, TypeVar, Coroutine, Any
T = TypeVar('T')
# Basic async function.
# NOTE(review): relies on `aiohttp`, which is not in this snippet's
# imports — add `import aiohttp` before running.
async def fetch_data(url: str) -> dict:
    """Fetch `url` and decode the response body as JSON."""
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            return await response.json()
# Concurrent execution with gather.
async def fetch_all(urls: List[str]) -> List[dict]:
    """Fetch all URLs concurrently.

    With ``return_exceptions=True`` a failed fetch is returned in-place as
    an exception object instead of cancelling the batch, so despite the
    annotation the result may mix dicts and exceptions.
    """
    tasks = [fetch_data(url) for url in urls]
    return await asyncio.gather(*tasks, return_exceptions=True)
# Timeout handling.
async def fetch_with_timeout(url: str, timeout: float = 5.0) -> dict:
    """Fetch `url`, raising TimeoutError after `timeout` seconds."""
    try:
        result = await asyncio.wait_for(fetch_data(url), timeout=timeout)
    except asyncio.TimeoutError:
        # Re-raise as the builtin TimeoutError with a descriptive message.
        raise TimeoutError(f"Request to {url} timed out after {timeout}s")
    return result
# Rate limiting with a semaphore.
async def fetch_with_rate_limit(urls: List[str], max_concurrent: int = 10) -> List[dict]:
    """Fetch every URL with at most `max_concurrent` requests in flight."""
    gate = asyncio.Semaphore(max_concurrent)

    async def limited_fetch(url: str) -> dict:
        # The semaphore caps how many fetches run concurrently.
        async with gate:
            return await fetch_data(url)

    coros = [limited_fetch(url) for url in urls]
    return await asyncio.gather(*coros)
# Async context manager.
# NOTE(review): depends on `asyncpg`, which is not imported here, and the
# `connect(...)` literal is a placeholder — fill in real connection params.
class AsyncDatabaseConnection:
    async def __aenter__(self):
        # Open the connection on entry and hand it to the `with` body.
        self.conn = await asyncpg.connect(...)
        return self.conn
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Always close, whether or not the body raised.
        await self.conn.close()
# Async iterator.
class AsyncPaginator:
    """Async iterator yielding pages from `fetch_page` until one is empty."""

    def __init__(self, fetch_page):
        self.fetch_page = fetch_page  # async callable: page index -> page data
        self.page = 0                 # next page index to request
        self.done = False             # set once an empty page is seen

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.done:
            raise StopAsyncIteration
        page_data = await self.fetch_page(self.page)
        if not page_data:
            # An empty/falsy page marks the end of pagination.
            self.done = True
            raise StopAsyncIteration
        self.page += 1
        return page_data
3. TypeScript 异步模式
// Fetch every URL concurrently and decode each response as JSON.
async function fetchAll<T>(urls: string[]): Promise<T[]> {
  const requests = urls.map(async (url) => {
    const response = await fetch(url);
    return response.json();
  });
  return Promise.all(requests);
}
// Fetch every URL; failed requests become Error values instead of
// rejecting the whole batch.
async function fetchAllSafe<T>(urls: string[]): Promise<Array<T | Error>> {
  const settled = await Promise.allSettled(
    urls.map((url) => fetch(url).then((r) => r.json())),
  );
  return settled.map((outcome) => {
    if (outcome.status === "fulfilled") {
      return outcome.value;
    }
    return new Error(outcome.reason);
  });
}
// Run `fn` over `items` with at most `limit` calls in flight.
//
// Fixes over the original:
// - Results are stored by input index, so output order matches `items`
//   (the original pushed in completion order).
// - Each promise removes itself from the in-flight set when it settles;
//   the original always spliced out the most recently started promise
//   after a race, which could evict a still-pending promise and let the
//   pool exceed `limit`.
async function fetchWithConcurrencyLimit<T>(
  items: string[],
  fn: (item: string) => Promise<T>,
  limit: number,
): Promise<T[]> {
  const results: T[] = new Array(items.length);
  const inFlight = new Set<Promise<void>>();
  for (let i = 0; i < items.length; i++) {
    const index = i;
    const task: Promise<void> = fn(items[index]).then((value) => {
      results[index] = value;
      inFlight.delete(task);
    });
    inFlight.add(task);
    if (inFlight.size >= limit) {
      // Wait for any in-flight call to finish before starting another.
      await Promise.race(inFlight);
    }
  }
  await Promise.all(inFlight);
  return results;
}
// Unbounded FIFO queue whose dequeue() awaits until an item is available.
class AsyncQueue<T> {
  private queue: T[] = [];
  private resolvers: Array<(value: T) => void> = [];

  // Hand the item straight to a waiting consumer if any, else buffer it.
  async enqueue(item: T): Promise<void> {
    if (this.resolvers.length === 0) {
      this.queue.push(item);
      return;
    }
    this.resolvers.shift()!(item);
  }

  // Resolve immediately from the buffer, or park until enqueue() runs.
  async dequeue(): Promise<T> {
    if (this.queue.length === 0) {
      return new Promise((resolve) => this.resolvers.push(resolve));
    }
    return this.queue.shift()!;
  }
}
4. Go 并发模式
package main
import (
"context"
"fmt"
"sync"
"time"
)
// fetchData simulates a fetch of url and sends a result string on ch.
// NOTE(review): ch must be non-nil and either buffered or actively read;
// a send on a nil channel blocks forever and leaks the goroutine.
func fetchData(url string, ch chan<- string) {
	// Simulate network latency.
	time.Sleep(100 * time.Millisecond)
	ch <- fmt.Sprintf("Data from %s", url)
}
// fetchAll fans out one goroutine per URL and gathers results from a
// buffered channel. Result order follows completion order, not input order.
func fetchAll(urls []string) []string {
	ch := make(chan string, len(urls))
	for _, url := range urls {
		go fetchData(url, ch)
	}
	results := make([]string, 0, len(urls))
	for range urls {
		results = append(results, <-ch)
	}
	return results
}
// fetchAllWithWaitGroup fills a pre-sized slice from worker goroutines and
// waits for all of them with a WaitGroup. Each worker writes only its own
// index, so no extra locking is needed.
func fetchAllWithWaitGroup(urls []string) []string {
	results := make([]string, len(urls))
	var wg sync.WaitGroup
	for i, url := range urls {
		wg.Add(1)
		// Pass loop variables as arguments so each goroutine gets its
		// own copies.
		go func(idx int, u string) {
			defer wg.Done()
			results[idx] = fmt.Sprintf("Data from %s", u)
		}(i, url)
	}
	wg.Wait()
	return results
}
// fetchWithTimeout returns the (simulated) fetch result for url, or the
// context's error if ctx is cancelled or times out first.
// The result channel is buffered (capacity 1) so the worker goroutine can
// still send and exit even after the caller has given up — without the
// buffer, that goroutine would leak.
func fetchWithTimeout(ctx context.Context, url string) (string, error) {
	ch := make(chan string, 1)
	go func() {
		time.Sleep(100 * time.Millisecond)
		ch <- fmt.Sprintf("Data from %s", url)
	}()
	select {
	case result := <-ch:
		return result, nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}
// workerPool runs numWorkers goroutines that drain jobs until the jobs
// channel is closed, then closes results.
// NOTE(review): this function blocks until the caller closes(jobs); run it
// in its own goroutine when the caller is also the producer.
func workerPool(jobs <-chan string, results chan<- string, numWorkers int) {
	var wg sync.WaitGroup
	for i := 0; i < numWorkers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			// range exits once jobs is closed and drained.
			for job := range jobs {
				results <- fmt.Sprintf("Worker %d processed %s", id, job)
			}
		}(i)
	}
	// Close results only after every worker has finished sending.
	wg.Wait()
	close(results)
}
// rateLimit fetches the URLs at roughly requestsPerSecond, pacing launches
// with a time.Ticker, and waits for every fetch to finish.
//
// Fix: the original passed a nil channel to fetchData; sending on a nil
// channel blocks forever, so every spawned goroutine leaked. A buffered
// results channel is created and drained instead.
func rateLimit(urls []string, requestsPerSecond int) {
	ticker := time.NewTicker(time.Second / time.Duration(requestsPerSecond))
	defer ticker.Stop()
	results := make(chan string, len(urls))
	for _, url := range urls {
		// Each tick releases one request slot.
		<-ticker.C
		go fetchData(url, results)
	}
	// Drain so no goroutine stays blocked on its send.
	for range urls {
		<-results
	}
}
// fanIn merges two input channels into one output channel. The output is
// closed only after BOTH inputs are closed. A closed input is set to nil so
// its select case can never fire again (receiving from a nil channel blocks
// forever, and select skips blocked cases in favor of ready ones).
func fanIn(ch1, ch2 <-chan string) <-chan string {
	out := make(chan string)
	go func() {
		defer close(out)
		for {
			select {
			case val, ok := <-ch1:
				if !ok {
					ch1 = nil // ch1 closed: disable this case
				} else {
					out <- val
				}
			case val, ok := <-ch2:
				if !ok {
					ch2 = nil // ch2 closed: disable this case
				} else {
					out <- val
				}
			}
			// Checked after each iteration, so select never runs with
			// both channels nil (which would block forever).
			if ch1 == nil && ch2 == nil {
				return
			}
		}
	}()
	return out
}
// SafeCounter is an integer counter safe for concurrent use.
type SafeCounter struct {
	mu    sync.Mutex
	count int
}

// Inc adds one to the counter under the lock.
func (c *SafeCounter) Inc() {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.count++
}

// Value reports the current count under the lock.
func (c *SafeCounter) Value() int {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.count
}
// Cache is a string map guarded by an RWMutex: many concurrent readers,
// exclusive writers.
type Cache struct {
	mu   sync.RWMutex
	data map[string]string
}

// NewCache returns a ready-to-use Cache.
//
// Fix: a zero-value Cache has a nil map, so Set panicked with
// "assignment to entry in nil map". This constructor initializes the map.
func NewCache() *Cache {
	return &Cache{data: make(map[string]string)}
}

// Get returns the value for key under a shared read lock.
func (c *Cache) Get(key string) (string, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	val, ok := c.data[key]
	return val, ok
}

// Set stores key=value under the exclusive write lock, lazily creating
// the map so zero-value Caches built without NewCache still work.
func (c *Cache) Set(key, value string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.data == nil {
		c.data = make(map[string]string)
	}
	c.data[key] = value
}
// One-time initialization with sync.Once: the initializer runs exactly
// once across all goroutines.
var (
	instance *Singleton
	once     sync.Once
)

// Singleton holds lazily-initialized shared state.
type Singleton struct {
	value string
}

// GetInstance returns the process-wide Singleton, creating it on first
// call. Concurrent callers block until initialization completes.
func GetInstance() *Singleton {
	once.Do(func() {
		instance = &Singleton{value: "initialized"}
	})
	return instance
}
5. 并行与并发
import asyncio
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
# Concurrency: I/O-bound tasks (use async or threads).
async def io_bound_concurrent():
    """Use for network calls, file I/O, database queries."""
    # NOTE(review): `aiohttp` and `urls` are not defined in this snippet —
    # this is a shape example, not runnable code as-is.
    async with aiohttp.ClientSession() as session:
        tasks = [session.get(url) for url in urls]
        return await asyncio.gather(*tasks)
def io_bound_threaded(urls: List[str]):
    """Use threads for I/O when async is not available."""
    # NOTE(review): `requests` is not imported in this snippet.
    with ThreadPoolExecutor(max_workers=10) as executor:
        return list(executor.map(requests.get, urls))
# Parallelism: CPU-bound tasks (use processes).
def cpu_bound_parallel(data: List[int]) -> List[int]:
    """Use for heavy computation - bypasses the GIL."""
    # NOTE(review): `heavy_computation` must be a module-level (picklable)
    # function for ProcessPoolExecutor to ship it to worker processes.
    with ProcessPoolExecutor() as executor:
        return list(executor.map(heavy_computation, data))
# Hybrid: CPU work mixed with I/O.
async def hybrid_processing(items: List[dict]):
    """Combine async I/O with parallel CPU processing."""
    # NOTE(review): `fetch` and `process_batch` are not defined here.
    loop = asyncio.get_event_loop()
    # Fetch the data concurrently on the event loop.
    raw_data = await asyncio.gather(*[fetch(item) for item in items])
    # Hand CPU-heavy work to a process pool; run_in_executor keeps the
    # event loop responsive while the pool works.
    with ProcessPoolExecutor() as executor:
        processed = await loop.run_in_executor(
            executor,
            process_batch,
            raw_data
        )
    return processed
6. 竞态条件与预防
import threading
import asyncio
from contextlib import contextmanager
# Thread-safe counter guarded by a lock.
class ThreadSafeCounter:
    """Integer counter whose updates are serialized by a mutex."""

    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def increment(self):
        """Atomically add one and return the new count."""
        with self._lock:
            self._value += 1
            return self._value

    @property
    def value(self):
        """Current count, read under the lock for a consistent view."""
        with self._lock:
            return self._value
# Read-write lock for optimized concurrent access.
class ReadWriteLock:
    """Many concurrent readers OR one exclusive writer.

    NOTE(review): the writer keeps the underlying condition lock held for
    its whole critical section (across the `yield`), which also blocks new
    readers from registering. That is correct for exclusion, but readers
    arriving continuously can starve writers — confirm that trade-off is
    acceptable before production use.
    """
    def __init__(self):
        # Condition over a plain Lock; the same lock also serializes writers.
        self._read_ready = threading.Condition(threading.Lock())
        self._readers = 0  # number of readers currently inside read_lock
    @contextmanager
    def read_lock(self):
        # Register as a reader; blocked while a writer holds _read_ready.
        with self._read_ready:
            self._readers += 1
        try:
            yield
        finally:
            with self._read_ready:
                self._readers -= 1
                if self._readers == 0:
                    # Last reader out wakes any waiting writer.
                    self._read_ready.notify_all()
    @contextmanager
    def write_lock(self):
        with self._read_ready:
            # wait() releases the condition lock while sleeping and
            # reacquires it before returning.
            while self._readers > 0:
                self._read_ready.wait()
            # Condition lock stays held across the yield, excluding both
            # other writers and new readers.
            yield
# Async lock for async code.
class AsyncSafeCache:
    """Cache whose population is serialized with an asyncio.Lock."""

    def __init__(self):
        self._cache = {}
        self._lock = asyncio.Lock()

    async def get_or_set(self, key: str, factory):
        """Return the cached value for `key`, awaiting `factory` exactly
        once even when called concurrently for the same key."""
        async with self._lock:
            if key in self._cache:
                return self._cache[key]
            value = await factory()
            self._cache[key] = value
            return value
# 用于无锁操作的比较并交换
import atomics # 或使用线程原语
class LockFreeCounter:
    """Counter based on compare-and-swap rather than a mutex.

    NOTE(review): depends on the third-party `atomics` package; verify that
    `cmpxchg_weak` returns a truthy value on success in the installed
    version before trusting this truthiness check.
    """
    def __init__(self):
        # 8-byte signed-integer atomic cell.
        self._value = atomics.atomic(width=8, atype=atomics.INT)
    def increment(self):
        # CAS retry loop: reload and retry until the swap succeeds (a weak
        # CAS is allowed to fail spuriously).
        while True:
            current = self._value.load()
            if self._value.cmpxchg_weak(current, current + 1):
                return current + 1
7. 死锁检测与预防
import threading
from collections import defaultdict
from typing import Dict, Set
import time
# Deadlock prevention via lock ordering.
class OrderedLockManager:
    """Prevents deadlock by enforcing a global lock-acquisition order.

    Locks are ranked in registration order; a thread may only acquire a
    lock ranked >= every lock it already holds, which makes a circular
    wait (and therefore deadlock) impossible.

    NOTE(review): `@contextmanager` requires `from contextlib import
    contextmanager`, which this snippet's imports omit. Also note that
    `register_lock` returns a NEW Lock on every call, even for a name that
    was already registered — callers must hold on to the returned lock.
    """
    def __init__(self):
        self._lock_order: Dict[str, int] = {}  # lock name -> rank
        self._next_order = 0                   # next rank to hand out
        # thread id -> names of locks that thread currently holds
        self._thread_locks: Dict[int, Set[str]] = defaultdict(set)
        self._meta_lock = threading.Lock()     # guards the tables above
    def register_lock(self, name: str) -> threading.Lock:
        with self._meta_lock:
            if name not in self._lock_order:
                self._lock_order[name] = self._next_order
                self._next_order += 1
            return threading.Lock()
    @contextmanager
    def acquire(self, lock: threading.Lock, name: str):
        thread_id = threading.current_thread().ident
        # Fail fast if taking `name` now would violate the global order.
        held_locks = self._thread_locks[thread_id]
        for held_name in held_locks:
            if self._lock_order[name] < self._lock_order[held_name]:
                raise RuntimeError(
                    f"Lock ordering violation: {name} < {held_name}"
                )
        lock.acquire()
        self._thread_locks[thread_id].add(name)
        try:
            yield
        finally:
            self._thread_locks[thread_id].discard(name)
            lock.release()
# Timeout-based deadlock detection.
class DeadlockError(RuntimeError):
    """Raised when a lock cannot be acquired in time (suspected deadlock)."""
    # Fix: the original raised DeadlockError without defining it anywhere,
    # so a timeout produced a NameError instead of the intended exception.


class TimeoutLock:
    """Mutex wrapper that raises DeadlockError instead of blocking forever."""

    def __init__(self, timeout: float = 5.0):
        self._lock = threading.Lock()
        self._timeout = timeout

    def acquire(self):
        """Acquire the lock or raise DeadlockError after the timeout."""
        acquired = self._lock.acquire(timeout=self._timeout)
        if not acquired:
            raise DeadlockError(
                f"Failed to acquire lock within {self._timeout}s - possible deadlock"
            )
        return True

    def release(self):
        self._lock.release()

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, *args):
        self.release()
# Deadlock detection with a wait-for graph.
class DeadlockError(RuntimeError):
    """Raised when the wait-for graph contains a cycle."""
    # Fix: DeadlockError was referenced but never defined in the original,
    # so detecting a deadlock raised NameError.


class DeadlockDetector:
    """Detects deadlocks by finding cycles in a thread wait-for graph."""

    def __init__(self):
        # thread id -> id of the thread it is waiting on (one edge per node)
        self._wait_for: Dict[int, int] = {}
        self._lock = threading.Lock()

    def register_wait(self, waiting_thread: int, holding_thread: int):
        """Record that `waiting_thread` waits on `holding_thread`;
        raise DeadlockError if this closes a cycle."""
        with self._lock:
            self._wait_for[waiting_thread] = holding_thread
            if self._has_cycle():
                raise DeadlockError("Deadlock detected in wait-for graph")

    def unregister_wait(self, thread: int):
        with self._lock:
            self._wait_for.pop(thread, None)

    def _has_cycle(self) -> bool:
        """DFS cycle check. Caller must hold self._lock."""
        visited = set()
        rec_stack = set()

        def dfs(node):
            visited.add(node)
            rec_stack.add(node)
            next_node = self._wait_for.get(node)
            # Fix: compare against None explicitly — the original
            # `if next_node:` silently skipped a node whose id is 0,
            # missing cycles that pass through it.
            if next_node is not None:
                if next_node not in visited:
                    if dfs(next_node):
                        return True
                elif next_node in rec_stack:
                    return True
            rec_stack.remove(node)
            return False

        for node in self._wait_for:
            if node not in visited:
                if dfs(node):
                    return True
        return False
8. 线程安全模式
import threading
from functools import wraps
from typing import TypeVar, Generic
T = TypeVar('T')
# Thread-local storage.
class RequestContext:
    """Per-thread request context: each thread sees only its own user id."""
    _local = threading.local()

    @classmethod
    def set_user(cls, user_id: str):
        # Stored on the thread-local, so it never leaks across threads.
        cls._local.user_id = user_id

    @classmethod
    def get_user(cls) -> str:
        # None when this thread never called set_user.
        try:
            return cls._local.user_id
        except AttributeError:
            return None
# Immutable data is inherently thread-safe: no writes, no races.
from dataclasses import dataclass
from typing import Tuple
@dataclass(frozen=True)
class ImmutableConfig:
    # frozen=True makes instances read-only (assignment raises
    # FrozenInstanceError), so they can be shared freely across threads.
    host: str
    port: int
    options: Tuple[str, ...]  # tuple rather than list: deeply immutable
# Copy-on-write pattern.
class CopyOnWriteList(Generic[T]):
    """List where every mutation builds a fresh tuple, so readers never
    observe a half-updated structure."""

    def __init__(self):
        self._data: Tuple[T, ...] = ()
        self._lock = threading.Lock()

    def append(self, item: T):
        # Writers serialize on the lock; the published tuple is swapped in
        # atomically by rebinding the attribute.
        with self._lock:
            self._data = self._data + (item,)

    def __iter__(self):
        # Iterate an immutable snapshot - no lock needed.
        snapshot = self._data
        return iter(snapshot)
# Thread-safe singleton via double-checked locking.
class Singleton:
    _instance = None          # shared instance, created lazily
    _lock = threading.Lock()  # guards first-time creation
    def __new__(cls):
        # First check skips the lock on the fast path; the second check,
        # inside the lock, ensures only one thread creates the instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance
# Synchronization decorator.
def synchronized(lock: threading.Lock = None):
    """Decorator factory: serialize calls to the wrapped function with `lock`.

    When no lock is given, one is created on first decoration and — due to
    the `nonlocal` rebinding — shared by every function decorated via the
    SAME `synchronized(...)` return value.
    """
    def decorator(func):
        nonlocal lock
        if lock is None:
            lock = threading.Lock()
        @wraps(func)
        def wrapper(*args, **kwargs):
            with lock:
                return func(*args, **kwargs)
        return wrapper
    return decorator
9. 工作队列与工作池
import asyncio
import queue
import threading
from typing import Callable, Any, List
from dataclasses import dataclass
from concurrent.futures import Future
@dataclass
class Job:
    # A unit of work queued to a pool: call func(*args, **kwargs) and
    # deliver the outcome through `future`.
    func: Callable
    args: tuple
    kwargs: dict
    future: Future  # resolved with the result, or the raised exception
# Thread-based worker pool.
class ThreadWorkerPool:
    """Fixed pool of daemon threads consuming Jobs from a shared queue."""

    def __init__(self, num_workers: int = 4):
        self._queue = queue.Queue()
        self._workers: List[threading.Thread] = []
        self._shutdown = False
        for _ in range(num_workers):
            thread = threading.Thread(target=self._worker_loop, daemon=True)
            thread.start()
            self._workers.append(thread)

    def _worker_loop(self):
        # Poll with a timeout so the loop notices _shutdown within ~1s.
        while not self._shutdown:
            try:
                job = self._queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                job.future.set_result(job.func(*job.args, **job.kwargs))
            except Exception as exc:
                job.future.set_exception(exc)

    def submit(self, func: Callable, *args, **kwargs) -> Future:
        """Queue a call; the returned Future resolves when a worker runs it."""
        future = Future()
        self._queue.put(Job(func, args, kwargs, future))
        return future

    def shutdown(self, wait: bool = True):
        """Stop workers after their current poll cycle; optionally join them."""
        self._shutdown = True
        if wait:
            for worker in self._workers:
                worker.join()
# Async worker pool.
class AsyncWorkerPool:
    """Pool of asyncio tasks pulling (func, args, kwargs, future) jobs from
    an asyncio.Queue. Supports both sync and async callables."""

    def __init__(self, num_workers: int = 10):
        self._queue: asyncio.Queue = asyncio.Queue()
        self._num_workers = num_workers
        self._workers: List[asyncio.Task] = []

    async def start(self):
        """Spawn the worker tasks; must be called inside a running loop."""
        for _ in range(self._num_workers):
            task = asyncio.create_task(self._worker_loop())
            self._workers.append(task)

    async def _worker_loop(self):
        while True:
            job = await self._queue.get()
            if job is None:  # shutdown signal
                # Fix: the original skipped task_done() on the sentinel
                # path, which would hang any `queue.join()` caller forever.
                self._queue.task_done()
                break
            func, args, kwargs, future = job
            try:
                if asyncio.iscoroutinefunction(func):
                    result = await func(*args, **kwargs)
                else:
                    result = func(*args, **kwargs)
                future.set_result(result)
            except Exception as e:
                future.set_exception(e)
            finally:
                self._queue.task_done()

    async def submit(self, func: Callable, *args, **kwargs) -> Any:
        """Queue a call and await its result."""
        future = asyncio.Future()
        await self._queue.put((func, args, kwargs, future))
        return await future

    async def shutdown(self):
        """Send one sentinel per worker and wait for all workers to exit."""
        for _ in self._workers:
            await self._queue.put(None)
        await asyncio.gather(*self._workers)
# Priority-queue worker pool.
class PriorityWorkerPool:
    """Worker pool executing jobs lowest-priority-number first.

    Fix: the original queued (priority, job) pairs; with equal priorities,
    PriorityQueue fell back to comparing Job objects, which are not
    orderable, raising TypeError inside the workers. A monotonically
    increasing sequence number is inserted as a tie-breaker, so entries
    are always comparable (and FIFO within one priority level).
    """

    def __init__(self, num_workers: int = 4):
        self._queue = queue.PriorityQueue()
        self._workers: List[threading.Thread] = []
        self._shutdown = False
        self._seq = 0                      # tie-breaker counter
        self._seq_lock = threading.Lock()  # guards _seq
        for _ in range(num_workers):
            worker = threading.Thread(target=self._worker_loop, daemon=True)
            worker.start()
            self._workers.append(worker)

    def _worker_loop(self):
        # Poll with a timeout so the loop notices _shutdown within ~1s.
        while not self._shutdown:
            try:
                priority, _seq, job = self._queue.get(timeout=1)
                try:
                    result = job.func(*job.args, **job.kwargs)
                    job.future.set_result(result)
                except Exception as e:
                    job.future.set_exception(e)
            except queue.Empty:
                continue

    def submit(self, func: Callable, *args, priority: int = 0, **kwargs) -> Future:
        """Queue a call with the given priority (lower runs first)."""
        future = Future()
        job = Job(func, args, kwargs, future)
        with self._seq_lock:
            seq = self._seq
            self._seq += 1
        self._queue.put((priority, seq, job))
        return future
最佳实践
-
优先使用异步进行 I/O:用于网络和文件 I/O 操作时使用异步/等待。
-
使用进程进行 CPU 工作:使用 ProcessPoolExecutor 绕过 GIL 进行 CPU 密集型任务。
-
最小化共享状态:优先使用消息传递而非共享内存。
-
锁排序:始终以一致顺序获取锁以预防死锁。
-
保持临界区小:持有锁的时间尽可能短。
-
使用高级抽象:优先使用队列、Future 和异步模式而非原始锁。
-
测试竞态条件:使用 ThreadSanitizer 等工具和压力测试。
-
记录线程安全性:清晰记录哪些方法是线程安全的。
示例
完整的带速率限制的异步网络爬虫
import asyncio
import aiohttp
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class ScrapeResult:
    # Outcome of scraping a single URL.
    url: str
    status: int                  # HTTP status; 0 when the request itself failed
    content: Optional[str]       # body text, or None on failure
    error: Optional[str] = None  # exception text when the request failed
class AsyncScraper:
    """Scraper combining a concurrency cap (semaphore) with a minimum
    interval between request starts (simple rate limiter).

    NOTE(review): requires `aiohttp` (imported at the top of this example).
    """
    def __init__(
        self,
        max_concurrent: int = 10,
        requests_per_second: float = 5.0
    ):
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.rate_limit = 1.0 / requests_per_second  # min seconds between starts
        self.last_request_time = 0
        self._lock = asyncio.Lock()  # serializes rate-limit bookkeeping
    async def _rate_limit(self):
        """Sleep just long enough to keep request starts `rate_limit` apart."""
        async with self._lock:
            now = asyncio.get_event_loop().time()
            wait_time = self.last_request_time + self.rate_limit - now
            if wait_time > 0:
                await asyncio.sleep(wait_time)
            self.last_request_time = asyncio.get_event_loop().time()
    async def scrape_url(
        self,
        session: aiohttp.ClientSession,
        url: str
    ) -> ScrapeResult:
        """Fetch one URL; any exception becomes an error-carrying result."""
        async with self.semaphore:
            await self._rate_limit()
            try:
                async with session.get(url, timeout=10) as response:
                    content = await response.text()
                    return ScrapeResult(
                        url=url,
                        status=response.status,
                        content=content
                    )
            except Exception as e:
                # Failures are reported as data (status=0) rather than
                # raised, so one bad URL does not abort the batch.
                return ScrapeResult(
                    url=url,
                    status=0,
                    content=None,
                    error=str(e)
                )
    async def scrape_all(self, urls: List[str]) -> List[ScrapeResult]:
        """Scrape every URL over one shared HTTP session."""
        async with aiohttp.ClientSession() as session:
            tasks = [self.scrape_url(session, url) for url in urls]
            return await asyncio.gather(*tasks)
# Usage example.
async def main():
    scraper = AsyncScraper(max_concurrent=5, requests_per_second=2.0)
    urls = ["https://example.com"] * 20
    results = await scraper.scrape_all(urls)
    # Report each URL's outcome.
    for result in results:
        if result.error:
            print(f"Failed: {result.url} - {result.error}")
        else:
            print(f"Success: {result.url} - {result.status}")
asyncio.run(main())
数据管道并行性
用于 ETL、流处理和批处理数据管道的并发阶段模式。
import asyncio
from typing import AsyncIterator, Callable, TypeVar, List
from dataclasses import dataclass
import queue
import threading
T = TypeVar('T')
U = TypeVar('U')
# Async pipeline with backpressure.
class AsyncPipeline:
    """
    Multi-stage async pipeline with bounded queues for backpressure.
    Each stage processes items with up to a worker-count of concurrency.
    """
    def __init__(self, max_queue_size: int = 100):
        self.max_queue_size = max_queue_size

    async def stage(
        self,
        input_iter: AsyncIterator[T],
        transform: Callable[[T], U],
        workers: int = 4
    ) -> AsyncIterator[U]:
        """Single pipeline stage with concurrent workers.

        Fixes over the original:
        - Workers now put a `None` sentinel on the OUTPUT queue when they
          finish. The original consumer counted output sentinels that were
          never sent, so the stage hung forever once input was exhausted.
        - The producer task handle is kept (and awaited) so the task cannot
          be garbage-collected mid-flight.

        Output order follows completion order, not input order.
        """
        queue_in = asyncio.Queue(maxsize=self.max_queue_size)
        queue_out = asyncio.Queue(maxsize=self.max_queue_size)

        # Producer: feed the input queue, then one sentinel per worker.
        async def producer():
            async for item in input_iter:
                await queue_in.put(item)
            for _ in range(workers):
                await queue_in.put(None)

        # Worker: transform items; signal completion on the output queue.
        async def worker():
            while True:
                item = await queue_in.get()
                if item is None:
                    break
                try:
                    if asyncio.iscoroutinefunction(transform):
                        result = await transform(item)
                    else:
                        result = transform(item)
                    await queue_out.put(result)
                except Exception as e:
                    # Exceptions travel through the queue and are re-raised
                    # in the consumer.
                    await queue_out.put(e)
            await queue_out.put(None)  # this worker is done

        # Consumer: yield results until every worker has signalled done.
        async def consumer():
            finished_workers = 0
            while finished_workers < workers:
                result = await queue_out.get()
                if result is None:
                    finished_workers += 1
                    continue
                if isinstance(result, Exception):
                    raise result
                yield result

        producer_task = asyncio.create_task(producer())
        worker_tasks = [asyncio.create_task(worker()) for _ in range(workers)]

        async for item in consumer():
            yield item

        # Cleanup: surface any unexpected task failures.
        await asyncio.gather(producer_task, *worker_tasks)
# Process/thread-based pipeline for CPU-bound work.
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
class ParallelPipeline:
    """Pipeline whose CPU-bound stages use a process pool.

    NOTE(review): ProcessPoolExecutor pickles the callable and data to ship
    them to worker processes, so `transform`/`reducer` must be picklable —
    lambdas and the nested `reduce_chunk` closure below will fail on
    spawn-based platforms (macOS/Windows default). Verify on the target
    platform.
    """
    @staticmethod
    def map_stage(
        items: List[T],
        transform: Callable[[T], U],
        workers: int = None
    ) -> List[U]:
        """Parallel map stage using processes."""
        with ProcessPoolExecutor(max_workers=workers) as executor:
            return list(executor.map(transform, items))
    @staticmethod
    def filter_stage(
        items: List[T],
        predicate: Callable[[T], bool],
        workers: int = None
    ) -> List[T]:
        """Parallel filter stage."""
        # Threads (not processes) here, so the lambda below is fine.
        with ThreadPoolExecutor(max_workers=workers) as executor:
            results = executor.map(lambda x: (x, predicate(x)), items)
            return [item for item, keep in results if keep]
    @staticmethod
    def reduce_stage(
        items: List[T],
        reducer: Callable[[U, T], U],
        initial: U,
        chunk_size: int = 1000
    ) -> U:
        """Parallel reduction with chunking.

        NOTE(review): `initial` is folded into every chunk AND into the
        final combine, so it must be an identity element (0, "", [], ...)
        or the result will over-count it.
        """
        def reduce_chunk(chunk):
            result = initial
            for item in chunk:
                result = reducer(result, item)
            return result
        chunks = [items[i:i+chunk_size] for i in range(0, len(items), chunk_size)]
        with ProcessPoolExecutor() as executor:
            partial_results = list(executor.map(reduce_chunk, chunks))
        # Final reduction over the per-chunk partials.
        final = initial
        for partial in partial_results:
            final = reducer(final, partial)
        return final
# Streaming pipeline with batching and timeout.
class StreamingPipeline:
    """Helpers for unbounded streams: batching plus bounded-parallel batch
    processing."""
    @staticmethod
    async def batch_stream(
        stream: AsyncIterator[T],
        batch_size: int = 100,
        timeout: float = 1.0
    ) -> AsyncIterator[List[T]]:
        """Collect items into batches by size or timeout.

        NOTE(review): the deadline is only checked when a new item arrives;
        if the stream stalls, a partial batch is NOT flushed until the next
        item shows up (or the stream ends). Confirm this is acceptable for
        latency-sensitive consumers.
        """
        batch = []
        deadline = asyncio.get_event_loop().time() + timeout
        async for item in stream:
            batch.append(item)
            if len(batch) >= batch_size:
                yield batch
                batch = []
                deadline = asyncio.get_event_loop().time() + timeout
            elif asyncio.get_event_loop().time() >= deadline:
                if batch:
                    yield batch
                    batch = []
                deadline = asyncio.get_event_loop().time() + timeout
        # Flush whatever remains when the stream ends.
        if batch:
            yield batch
    @staticmethod
    async def parallel_batch_process(
        batched_stream: AsyncIterator[List[T]],
        process_batch: Callable[[List[T]], List[U]],
        max_concurrent: int = 4
    ) -> AsyncIterator[U]:
        """Process batches in parallel, at most `max_concurrent` at a time.

        `process_batch` runs in a worker thread (asyncio.to_thread), so a
        blocking function will not stall the event loop.
        """
        semaphore = asyncio.Semaphore(max_concurrent)
        async def process_with_limit(batch):
            async with semaphore:
                return await asyncio.to_thread(process_batch, batch)
        pending = set()
        async for batch in batched_stream:
            task = asyncio.create_task(process_with_limit(batch))
            pending.add(task)
            if len(pending) >= max_concurrent:
                # Yield results from whichever tasks finished first.
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED
                )
                for task in done:
                    results = await task
                    for result in results:
                        yield result
        # Drain the remaining in-flight tasks.
        while pending:
            done, pending = await asyncio.wait(pending)
            for task in done:
                results = await task
                for result in results:
                    yield result
# Complete ETL example.
@dataclass
class Record:
    # Minimal payload flowing through the ETL pipeline.
    id: int
    value: str
class ETLPipeline:
    """Complete ETL pipeline: extract -> transform (concurrent) -> batched load.

    Relies on AsyncPipeline and StreamingPipeline defined above.
    """
    async def extract(self) -> AsyncIterator[Record]:
        """Simulate pulling records from a source."""
        for i in range(1000):
            await asyncio.sleep(0.001)  # simulate I/O latency
            yield Record(id=i, value=f"raw_{i}")
    def transform(self, record: Record) -> Record:
        """CPU-bound transformation (hashing stands in for real work)."""
        import hashlib
        transformed = hashlib.sha256(record.value.encode()).hexdigest()
        return Record(id=record.id, value=transformed)
    async def load_batch(self, records: List[Record]):
        """Bulk-load a batch into the destination."""
        await asyncio.sleep(0.1)  # simulate a bulk write
        print(f"Loaded batch of {len(records)} records")
    async def run(self):
        """Run the full pipeline end to end."""
        pipeline = AsyncPipeline()
        # Stage 1: extract.
        extracted = self.extract()
        # Stage 2: transform with 8 concurrent workers.
        transformed = pipeline.stage(extracted, self.transform, workers=8)
        # Stage 3: batch and load.
        batched = StreamingPipeline.batch_stream(transformed, batch_size=50)
        async for batch in batched:
            await self.load_batch(batch)
# Usage: run the whole ETL pipeline.
async def main():
    etl = ETLPipeline()
    await etl.run()
asyncio.run(main())
Rust Data Pipeline with Rayon
use rayon::prelude::*;
use std::sync::{Arc, Mutex};
/// Square every element, keep the even squares, and sum them — each step
/// runs on rayon's thread pool.
fn parallel_pipeline(data: Vec<i32>) -> i32 {
    let squares = data.par_iter().map(|x| x * x);
    squares.filter(|x| x % 2 == 0).sum()
}
/// Eager parallel pipeline: each step materializes an intermediate Vec,
/// processed with rayon.
struct Pipeline<T> {
    data: Vec<T>,
}

impl<T: Send + Sync> Pipeline<T> {
    /// Wrap an existing vector.
    fn new(data: Vec<T>) -> Self {
        Self { data }
    }

    /// Apply `f` to every element in parallel.
    fn map<U, F>(self, f: F) -> Pipeline<U>
    where
        U: Send,
        F: Fn(T) -> U + Send + Sync,
    {
        let mapped = self.data.into_par_iter().map(f).collect();
        Pipeline { data: mapped }
    }

    /// Keep only elements matching `f`, in parallel.
    fn filter<F>(self, f: F) -> Pipeline<T>
    where
        F: Fn(&T) -> bool + Send + Sync,
    {
        let kept = self.data.into_par_iter().filter(f).collect();
        Pipeline { data: kept }
    }

    /// Unwrap the final vector.
    fn collect(self) -> Vec<T> {
        self.data
    }
}
// 异步流处理
use tokio::sync::mpsc;
use tokio_stream::{Stream, StreamExt};
// Async stream processing: apply `transform` to up to `parallelism` items
// concurrently and collect every output.
//
// A collector task drains the mpsc receiver while `for_each_concurrent`
// feeds it; `drop(tx)` closes the channel so the collector's `recv` loop
// terminates. Output order follows completion order, not input order.
async fn process_stream<T, U, F>(
    mut stream: impl Stream<Item = T> + Unpin,
    transform: F,
    parallelism: usize,
) -> Vec<U>
where
    T: Send + 'static,
    U: Send + 'static,
    F: Fn(T) -> U + Send + Sync + Clone + 'static,
{
    let (tx, mut rx) = mpsc::channel(100);
    let processor = tokio::spawn(async move {
        let mut results = Vec::new();
        while let Some(result) = rx.recv().await {
            results.push(result);
        }
        results
    });
    stream
        .for_each_concurrent(parallelism, |item| {
            let tx = tx.clone();
            let transform = transform.clone();
            async move {
                let result = transform(item);
                let _ = tx.send(result).await;
            }
        })
        .await;
    // Drop the last sender so the collector's channel closes and it exits.
    drop(tx);
    processor.await.unwrap()
}