name: python-jax description: JAX(Just After eXecution)的专家指导 - 具有自动微分、JIT编译、向量化和GPU/TPU加速的高性能数值计算;包括变换(grad、jit、vmap、pmap)、尖锐点、陷阱以及与NumPy的区别 allowed-tools: [“*”]
JAX - 高性能数值计算
概述
JAX 是一个Python库,用于加速器导向的数组计算和程序转换,专为高性能数值计算和大规模机器学习设计。它结合了熟悉的NumPy风格API和强大的函数变换,用于自动微分、编译、向量化和并行化。
核心价值: 编写类似NumPy的Python代码,并通过可组合的函数变换自动获取梯度、GPU/TPU加速、向量化和并行化——无需改变数学表示。
何时使用JAX
使用JAX时:
- 需要自动微分进行优化或机器学习
- 希望通过最小代码更改获得GPU/TPU加速
- 需要高性能数值计算
- 构建自定义梯度算法
- 需要自动向量化或并行化函数
- 研究需要灵活微分
- 希望以函数式编程方式编写数值代码
不要使用时:
- 简单的NumPy操作无需性能需求(开销不合理)
- 重度依赖原地突变(JAX数组不可变)
- 命令式、有状态的代码带副作用
- 需要依赖运行时数据值的控制流(支持有限)
- 使用需要NumPy数组的库(兼容性问题)
核心变换
JAX提供四个基础、可组合的变换:
1. jax.grad - 自动微分
使用反向模式自动微分计算梯度。
import jax
import jax.numpy as jnp
def loss(x):
return jnp.sum(x**2)
# 获取梯度函数
grad_loss = jax.grad(loss)
x = jnp.array([1.0, 2.0, 3.0])
gradient = grad_loss(x)
print(gradient) # [2. 4. 6.]
关键特性:
- 返回计算梯度的函数
- 可组合用于高阶导数
- 适用于复杂嵌套结构(pytrees)
2. jax.jit - 即时编译
编译函数到XLA以实现显著加速。
import jax
@jax.jit
def fast_function(x):
return jnp.sum(x**2 + 3*x + 1)
# 第一次调用:编译 + 执行(较慢)
result = fast_function(jnp.array([1.0, 2.0, 3.0]))
# 后续调用:缓存的编译版本(快得多)
result = fast_function(jnp.array([4.0, 5.0, 6.0]))
性能:
- 典型数值函数加速10-100倍
- 第一次调用较慢(编译开销)
- 相同形状/数据类型时缓存
3. jax.vmap - 自动向量化
自动跨批次维度向量化函数。
import jax
def process_single(x):
"""处理单个示例"""
return jnp.sum(x**2)
# 向量化以处理批次
process_batch = jax.vmap(process_single)
# 自动处理批次
batch = jnp.array([[1, 2], [3, 4], [5, 6]])
results = process_batch(batch)
print(results) # [5 25 61]
好处:
- 消除手动循环编写
- 通常比显式循环快
- 代码更干净、声明性更强
4. jax.pmap - 并行映射
跨多个设备(GPU/TPU)并行化。
import jax
@jax.pmap
def parallel_fn(x):
return x**2
# 在所有可用设备上运行
devices = jax.devices()
x = jnp.arange(len(devices))
results = parallel_fn(x)
用于:
- 多GPU/TPU计算
- 训练中的数据并行
- 大规模模拟
变换组合
JAX变换无缝组合:
# JIT编译函数的梯度
@jax.jit
def fast_loss(x):
return jnp.sum(x**2)
grad_fast_loss = jax.grad(fast_loss)
# 梯度的JIT(同样有效)
fast_grad_loss = jax.jit(jax.grad(fast_loss))
# 向量化梯度
batch_grad_loss = jax.vmap(jax.grad(fast_loss))
# JIT + vmap + grad
fast_batch_grad = jax.jit(jax.vmap(jax.grad(fast_loss)))
顺序影响性能但不影响正确性。
JAX vs NumPy - 关键区别
1. 不可变性
NumPy(可变):
import numpy as np
x = np.array([1, 2, 3])
x[0] = 10 # 原地修改x
print(x) # [10 2 3]
JAX(不可变):
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
# x[0] = 10 # 错误:JAX数组不可变
# 使用函数式更新代替
x = x.at[0].set(10) # 返回新数组
print(x) # [10 2 3]
函数式更新:
# 设置值
x = x.at[0].set(10)
# 增加值
x = x.at[0].add(5)
# 乘以值
x = x.at[0].mul(2)
# 最小/最大
x = x.at[0].min(5)
x = x.at[0].max(5)
# 多索引
x = x.at[0, 1].set(10)
x = x.at[[0, 2]].set(10)
x = x.at[0:3].set(10)
2. 随机数生成
NumPy(全局状态):
import numpy as np
np.random.seed(42)
x = np.random.normal(size=3)
y = np.random.normal(size=3) # 与x不同
JAX(显式密钥):
import jax
key = jax.random.PRNGKey(42)
# 分割密钥以独立随机
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(3,))
key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(3,))
# 并行随机数
keys = jax.random.split(key, num=10)
samples = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(keys)
密钥管理模式:
# 创建初始密钥
key = jax.random.PRNGKey(0)
# 使用前始终分割
key, subkey1, subkey2 = jax.random.split(key, 3)
# 使用子密钥进行随机操作
x = jax.random.normal(subkey1, shape=(10,))
y = jax.random.uniform(subkey2, shape=(10,))
# 绝不重用密钥!
3. 64位精度
JAX默认32位以优化性能。
import jax.numpy as jnp
x = jnp.array([1.0])
print(x.dtype) # float32(默认)
# 全局启用64位
from jax import config
config.update("jax_enable_x64", True)
x = jnp.array([1.0])
print(x.dtype) # float64
# 或使用环境变量
# JAX_ENABLE_X64=1 python script.py
4. 越界索引
NumPy(引发错误):
import numpy as np
x = np.array([1, 2, 3])
# x[10] # IndexError
JAX(静默截断):
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
print(x[10]) # 3(返回最后一个元素,无错误!)
# 更新也静默忽略
x = x.at[10].set(99) # 无错误,无效果
⚠️ 警告: 这是未定义行为 - 避免依赖它!
5. 非数组输入
NumPy(接受列表):
import numpy as np
result = np.sum([1, 2, 3]) # 工作
JAX(需要数组):
import jax.numpy as jnp
# result = jnp.sum([1, 2, 3]) # 错误
# 显式转换
result = jnp.sum(jnp.array([1, 2, 3])) # 工作
原因: 防止跟踪期间性能下降。
尖锐点与陷阱
1. 需纯函数
❌ 错误:副作用
counter = 0
@jax.jit
def impure_fn(x):
global counter
counter += 1 # 副作用 - 不可靠
return x**2
# 第一次调用:counter=1(跟踪)
result = impure_fn(2.0)
# 后续调用:counter仍为1(使用缓存跟踪)
result = impure_fn(3.0)
✅ 正确:纯函数
@jax.jit
def pure_fn(x, counter):
return x**2, counter + 1
result, counter = pure_fn(2.0, 0)
result, counter = pure_fn(3.0, counter)
2. 控制流限制
❌ 错误:值依赖控制流
@jax.jit
def conditional(x):
if x > 0: # 错误:x是跟踪器,非具体值
return x**2
else:
return x**3
✅ 正确:使用jax.lax.cond
@jax.jit
def conditional(x):
return jax.lax.cond(
x > 0,
lambda x: x**2, # 真分支
lambda x: x**3 # 假分支
)
循环:
# ❌ 错误:JIT中的Python for循环(编译时展开)
@jax.jit
def loop_bad(x, n):
for i in range(n): # n必须为静态
x = x + 1
return x
# ✅ 正确:使用jax.lax.fori_loop
@jax.jit
def loop_good(x, n):
def body(i, val):
return val + 1
return jax.lax.fori_loop(0, n, body, x)
while循环:
@jax.jit
def while_loop(x):
def cond_fun(val):
return val < 10
def body_fun(val):
return val + 1
return jax.lax.while_loop(cond_fun, body_fun, x)
3. 动态形状
❌ 错误:形状依赖运行时值
@jax.jit
def dynamic_shape(x, mask):
return x[mask] # 错误:编译时输出形状未知
✅ 正确:使用jnp.where进行掩码
@jax.jit
def static_shape(x, mask):
return jnp.where(mask, x, 0) # 形状保持不变
4. 原地更新语义
❌ 错误:依赖共享引用
x = jnp.array([1, 2, 3])
y = x
x = x.at[0].set(10)
print(y[0]) # 仍为1(y未改变)
注意: .at返回新数组;不修改原数组。
5. JIT中的打印调试
❌ 错误:JIT上下文中的print()
@jax.jit
def debug(x):
print(f"x = {x}") # 仅打印一次(跟踪期间)
return x**2
✅ 正确:使用jax.debug.print
@jax.jit
def debug(x):
jax.debug.print("x = {}", x) # 每次调用都打印
return x**2
6. 通过离散操作的梯度
❌ 错误:通过argmax的梯度
def loss(x):
idx = jnp.argmax(x) # 离散 - 无梯度
return x[idx]
# grad_loss = jax.grad(loss) # 梯度处处为零
✅ 正确:使用可微分近似
def loss(x):
# 软argmax(可微分)
weights = jax.nn.softmax(x * temperature)
return jnp.sum(x * weights)
grad_loss = jax.grad(loss)
自动微分
基本梯度
import jax
import jax.numpy as jnp
def f(x):
return jnp.sum(x**2)
grad_f = jax.grad(f)
x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x)) # [2. 4. 6.]
值与梯度
# 高效计算值和梯度
value_and_grad_f = jax.value_and_grad(f)
value, gradient = value_and_grad_f(x)
print(f"值: {value}, 梯度: {gradient}")
多参数
def f(x, y):
return jnp.sum(x**2 + y**3)
# 梯度相对于第一个参数(默认)
grad_f = jax.grad(f)
print(grad_f(x, y))
# 梯度相对于第二个参数
grad_f_wrt_y = jax.grad(f, argnums=1)
print(grad_f_wrt_y(x, y))
# 相对于两个参数的梯度
grad_f_both = jax.grad(f, argnums=(0, 1))
grad_x, grad_y = grad_f_both(x, y)
辅助数据
def loss_with_aux(x):
loss = jnp.sum(x**2)
aux_data = {'norm': jnp.linalg.norm(x), 'mean': jnp.mean(x)}
return loss, aux_data
# 告知JAX辅助输出
grad_fn = jax.grad(loss_with_aux, has_aux=True)
gradient, aux = grad_fn(x)
print(f"梯度: {gradient}")
print(f"辅助: {aux}")
高阶导数
# 二阶导数(Hessian对角线)
def f(x):
return jnp.sum(x**3)
grad_f = jax.grad(f)
hess_diag_f = jax.grad(lambda x: jnp.sum(grad_f(x) * x))
# 完整Hessian
hessian_f = jax.hessian(f)
x = jnp.array([1.0, 2.0, 3.0])
print(hessian_f(x))
Jacobian
def vector_fn(x):
"""向量到向量函数"""
return jnp.array([x[0]**2, x[1]**3, x[0]*x[1]])
# 前向模式(输入少输出多时高效)
jacfwd = jax.jacfwd(vector_fn)
# 反向模式(输入多输出少时高效)
jacrev = jax.jacrev(vector_fn)
x = jnp.array([2.0, 3.0])
print(jacfwd(x))
print(jacrev(x)) # 相同结果
自定义梯度
@jax.custom_gradient
def f(x):
# 前向传递
result = jnp.exp(x)
# 定义自定义梯度
def grad_fn(g):
# g是下游梯度
# 返回相对于输入的梯度
return g * 2 * result # 自定义:2倍正常梯度
return result, grad_fn
# 像正常一样使用
grad_f = jax.grad(f)
JIT编译
基本用法
import jax
import jax.numpy as jnp
def slow_fn(x):
return jnp.sum(x**2 + 3*x + 1)
fast_fn = jax.jit(slow_fn)
# 或装饰器
@jax.jit
def fast_fn2(x):
return jnp.sum(x**2 + 3*x + 1)
静态参数
@jax.jit
def fn(x, n):
for i in range(n): # n必须为静态
x = x + 1
return x
# 错误:n每次调用变化,触发重新编译
result = fn(x, 5)
result = fn(x, 10) # 重新编译!
# 解决方案:标记为静态
@jax.jit(static_argnums=(1,))
def fn_static(x, n):
for i in range(n):
x = x + 1
return x
# 现在正确工作
result = fn_static(x, 5)
result = fn_static(x, 10) # n=10的单独编译
避免重新编译
# 错误:不同形状触发重新编译
@jax.jit
def process(x):
return jnp.sum(x**2)
x1 = jnp.ones(10) # 编译形状(10,)
x2 = jnp.ones(20) # 重新编译形状(20,)
x3 = jnp.ones(10) # 使用缓存版本
# 正确:一致形状
batch_size = 32
x1 = jnp.ones((batch_size, 10))
x2 = jnp.ones((batch_size, 10)) # 相同形状,缓存
向量化(vmap)
基本批处理
def process_single(x):
"""处理单个示例:标量输入"""
return x**2 + 3*x
# 手动向量化
def process_batch_manual(xs):
return jnp.array([process_single(x) for x in xs])
# 自动向量化
process_batch = jax.vmap(process_single)
batch = jnp.array([1.0, 2.0, 3.0, 4.0])
print(process_batch(batch)) # 更快更干净
批处理矩阵操作
def matrix_vector_product(matrix, vector):
"""单个矩阵-向量乘积"""
return matrix @ vector
# 在向量上批处理
batch_mvp = jax.vmap(matrix_vector_product, in_axes=(None, 0))
A = jnp.ones((3, 3))
vectors = jnp.ones((10, 3)) # 10个向量
results = batch_mvp(A, vectors) # (10, 3)
# 在两者上批处理
batch_both = jax.vmap(matrix_vector_product, in_axes=(0, 0))
matrices = jnp.ones((10, 3, 3))
results = batch_both(matrices, vectors)
嵌套vmap
# 在两个维度上批处理
def fn(x, y):
return x * y
# 在第一个参数上批处理,然后在第二个上
fn_batch = jax.vmap(jax.vmap(fn, in_axes=(None, 0)), in_axes=(0, None))
x = jnp.array([1, 2, 3]) # (3,)
y = jnp.array([10, 20]) # (2,)
result = fn_batch(x, y) # (3, 2)
# 等同于:x[:, None] * y[None, :]
PyTrees - 嵌套结构
JAX适用于嵌套Python容器(pytrees):
import jax
import jax.numpy as jnp
# 数组字典
params = {
'w1': jnp.ones((10, 5)),
'b1': jnp.zeros(5),
'w2': jnp.ones((5, 1)),
'b2': jnp.zeros(1)
}
# 梯度适用于整个pytree
def loss(params, x):
h = x @ params['w1'] + params['b1']
h = jax.nn.relu(h)
out = h @ params['w2'] + params['b2']
return jnp.mean(out**2)
grad_fn = jax.grad(loss)
x = jnp.ones((32, 10))
# 返回相同结构的梯度
grads = grad_fn(params, x)
print(grads.keys()) # dict_keys(['w1', 'b1', 'w2', 'b2'])
PyTree操作
# 树映射 - 对所有叶子应用函数
scaled_params = jax.tree_map(lambda x: x * 0.9, params)
# 树归约
total_params = jax.tree_reduce(
lambda total, x: total + x.size,
params,
initializer=0
)
# 扁平化和反扁平化
flat, treedef = jax.tree_flatten(params)
reconstructed = jax.tree_unflatten(treedef, flat)
常见模式
优化循环
import jax
import jax.numpy as jnp
# 参数
params = jnp.array([1.0, 2.0, 3.0])
# 损失函数
def loss(params, x, y):
pred = jnp.dot(params, x)
return jnp.mean((pred - y)**2)
# 梯度函数
grad_fn = jax.jit(jax.grad(loss))
# 训练数据
x_train = jnp.ones((100, 3))
y_train = jnp.ones(100)
# 优化循环
learning_rate = 0.01
for step in range(1000):
grads = grad_fn(params, x_train, y_train)
params = params - learning_rate * grads
if step % 100 == 0:
l = loss(params, x_train, y_train)
print(f"步骤 {step}, 损失: {l:.4f}")
小批量训练
def train_step(params, batch):
"""单个训练步骤在一个批次上"""
x, y = batch
def batch_loss(params):
pred = jnp.dot(x, params)
return jnp.mean((pred - y)**2)
loss_value, grads = jax.value_and_grad(batch_loss)(params)
params = params - 0.01 * grads
return params, loss_value
# JIT编译训练步骤
train_step = jax.jit(train_step)
# 训练循环
for epoch in range(10):
for batch in data_loader:
params, loss = train_step(params, batch)
Scan用于高效循环
def cumulative_sum(xs):
"""使用scan的累积和"""
def step(carry, x):
new_carry = carry + x
output = new_carry
return new_carry, output
final_carry, outputs = jax.lax.scan(step, 0, xs)
return outputs
xs = jnp.array([1, 2, 3, 4, 5])
print(cumulative_sum(xs)) # [1 3 6 10 15]
带Scan的RNN
def rnn_step(carry, x):
"""单个RNN步骤"""
h = carry
h_new = jnp.tanh(jnp.dot(W_h, h) + jnp.dot(W_x, x))
return h_new, h_new
def rnn(xs, h0):
"""在序列上运行RNN"""
final_h, all_h = jax.lax.scan(rnn_step, h0, xs)
return all_h
# 处理序列
W_h = jnp.ones((5, 5))
W_x = jnp.ones((5, 3))
xs = jnp.ones((10, 3)) # 序列长度10
h0 = jnp.zeros(5)
outputs = rnn(xs, h0) # (10, 5)
性能最佳实践
1. JIT关键路径
# 将昂贵计算包装在jit中
@jax.jit
def expensive_fn(x):
for _ in range(100):
x = jnp.dot(x, x.T)
return x
# 避免JIT琐碎操作
def trivial(x):
return x + 1 # JIT开销不值得
2. 使用vmap替代循环
# 慢:Python循环
def slow_batch(xs):
return jnp.array([process(x) for x in xs])
# 快:vmap
fast_batch = jax.vmap(process)
3. 最小化主机-设备传输
# 错误:在循环中传输回主机
for i in range(1000):
x = compute_on_gpu(x)
print(float(x)) # 每次迭代传输到CPU!
# 正确:在结束时传输一次
for i in range(1000):
x = compute_on_gpu(x)
print(float(x)) # 单次传输
4. 使用合适精度
# 使用float32除非需要float64
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
# 仅在必要时启用float64
from jax import config
config.update("jax_enable_x64", True)
5. 尽可能预分配
# 错误:增长数组
result = jnp.array([])
for i in range(1000):
result = jnp.append(result, compute(i))
# 正确:预分配
result = jnp.zeros(1000)
for i in range(1000):
result = result.at[i].set(compute(i))
# 更好:使用scan或vmap
def compute_all(i):
return compute(i)
result = jax.vmap(compute_all)(jnp.arange(1000))
调试
检查数组值
# 在JIT上下文中使用jax.debug.print
@jax.jit
def debug_fn(x):
jax.debug.print("x = {}", x)
jax.debug.print("x shape = {}, dtype = {}", x.shape, x.dtype)
return x**2
梯度检查
from jax.test_util import check_grads
def f(x):
return jnp.sum(x**3)
x = jnp.array([1.0, 2.0, 3.0])
# 数值验证梯度
check_grads(f, (x,), order=2) # 检查一阶和二阶导数
检查编译代码
# 查看jaxpr(中间表示)
def f(x):
return x**2 + 3*x
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)
# 查看HLO(低级编译代码)
compiled = jax.jit(f).lower(1.0).compile()
print(compiled.as_text())
常见陷阱 - 快速参考
| 陷阱 | NumPy | JAX | 解决方案 |
|---|---|---|---|
| 原地更新 | x[0] = 1 |
❌ 错误 | x = x.at[0].set(1) |
| 随机状态 | np.random.seed() |
❌ 不可靠 | key = jax.random.PRNGKey() |
| 列表输入 | np.sum([1,2,3]) |
❌ 错误 | jnp.sum(jnp.array([1,2,3])) |
| 越界 | IndexError | ⚠️ 静默截断 | 避免,验证索引 |
| 值依赖if | 工作 | ❌ 在JIT中 | jax.lax.cond() |
| 动态形状 | 工作 | ❌ 在JIT中 | 保持形状静态 |
| 默认精度 | float64 | float32 | 设置jax_enable_x64 |
| JIT中打印 | 工作 | 仅一次 | jax.debug.print() |
安装
# 仅CPU
pip install jax
# GPU(CUDA 12)
pip install jax[cuda12]
# TPU
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
生态系统库
JAX在其变换原语上建立了丰富的生态系统:
神经网络:
- Flax - 官方神经网络库
- Haiku - DeepMind的神经网络库
- Equinox - 优雅的PyTorch风格库
优化:
- Optax - 梯度处理和优化
- JAXopt - 非线性优化
科学计算:
- JAX-MD - 分子动力学
- Diffrax - 微分方程求解器
- BlackJAX - MCMC采样
实用工具:
- jaxtyping - 数组类型注解
- chex - 测试实用工具
额外资源
- 官方文档: https://docs.jax.dev/
- JAX GitHub: https://github.com/google/jax
- JAX生态系统: https://github.com/google/jax#neural-network-libraries
- JAX教程(DeepMind): https://github.com/deepmind/jax-tutorial
- Awesome JAX: https://github.com/n2cholas/awesome-jax
相关技能
python-optimization- 使用scipy、pyomo进行数值优化python-ase- 原子模拟环境(可使用JAX进行力计算)pycse- 科学计算实用工具python-best-practices- JAX项目的代码质量