JAX高性能数值计算Skill python-jax

JAX是一个Python库,专注于高性能数值计算,支持自动微分、即时编译(JIT)、向量化和GPU/TPU加速,适用于机器学习、深度学习和科学计算等场景。核心功能包括变换组合(grad、jit、vmap、pmap)、PyTrees处理以及优化循环模式。关键词:JAX, 自动微分, 高性能计算, 机器学习, 深度学习, Python, GPU加速, 数值优化, 科学计算, AI框架。

机器学习 0 次安装 1 次浏览 更新于 3/12/2026

name: python-jax description: JAX(Just After eXecution)的专家指导 - 具有自动微分、JIT编译、向量化和GPU/TPU加速的高性能数值计算;包括变换(grad、jit、vmap、pmap)、尖锐点、陷阱以及与NumPy的区别 allowed-tools: [“*”]

JAX - 高性能数值计算

概述

JAX 是一个Python库,用于加速器导向的数组计算和程序转换,专为高性能数值计算和大规模机器学习设计。它结合了熟悉的NumPy风格API和强大的函数变换,用于自动微分、编译、向量化和并行化。

核心价值: 编写类似NumPy的Python代码,并通过可组合的函数变换自动获取梯度、GPU/TPU加速、向量化和并行化——无需改变数学表示。

何时使用JAX

使用JAX时:

  • 需要自动微分进行优化或机器学习
  • 希望通过最小代码更改获得GPU/TPU加速
  • 需要高性能数值计算
  • 构建自定义梯度算法
  • 需要自动向量化或并行化函数
  • 研究需要灵活微分
  • 希望以函数式编程方式编写数值代码

不要使用时:

  • 简单的NumPy操作无需性能需求(开销不合理)
  • 重度依赖原地突变(JAX数组不可变)
  • 命令式、有状态的代码带副作用
  • 需要依赖运行时数据值的控制流(支持有限)
  • 使用需要NumPy数组的库(兼容性问题)

核心变换

JAX提供四个基础、可组合的变换:

1. jax.grad - 自动微分

使用反向模式自动微分计算梯度。

import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(x**2)

# 获取梯度函数
grad_loss = jax.grad(loss)

x = jnp.array([1.0, 2.0, 3.0])
gradient = grad_loss(x)
print(gradient)  # [2. 4. 6.]

关键特性:

  • 返回计算梯度的函数
  • 可组合用于高阶导数
  • 适用于复杂嵌套结构(pytrees)

2. jax.jit - 即时编译

编译函数到XLA以实现显著加速。

import jax

@jax.jit
def fast_function(x):
    return jnp.sum(x**2 + 3*x + 1)

# 第一次调用:编译 + 执行(较慢)
result = fast_function(jnp.array([1.0, 2.0, 3.0]))

# 后续调用:缓存的编译版本(快得多)
result = fast_function(jnp.array([4.0, 5.0, 6.0]))

性能:

  • 典型数值函数加速10-100倍
  • 第一次调用较慢(编译开销)
  • 相同形状/数据类型时缓存

3. jax.vmap - 自动向量化

自动跨批次维度向量化函数。

import jax

def process_single(x):
    """处理单个示例"""
    return jnp.sum(x**2)

# 向量化以处理批次
process_batch = jax.vmap(process_single)

# 自动处理批次
batch = jnp.array([[1, 2], [3, 4], [5, 6]])
results = process_batch(batch)
print(results)  # [5 25 61]

好处:

  • 消除手动循环编写
  • 通常比显式循环快
  • 代码更干净、声明性更强

4. jax.pmap - 并行映射

跨多个设备(GPU/TPU)并行化。

import jax

@jax.pmap
def parallel_fn(x):
    return x**2

# 在所有可用设备上运行
devices = jax.devices()
x = jnp.arange(len(devices))
results = parallel_fn(x)

用于:

  • 多GPU/TPU计算
  • 训练中的数据并行
  • 大规模模拟

变换组合

JAX变换无缝组合:

# JIT编译函数的梯度
@jax.jit
def fast_loss(x):
    return jnp.sum(x**2)

grad_fast_loss = jax.grad(fast_loss)

# 梯度的JIT(同样有效)
fast_grad_loss = jax.jit(jax.grad(fast_loss))

# 向量化梯度
batch_grad_loss = jax.vmap(jax.grad(fast_loss))

# JIT + vmap + grad
fast_batch_grad = jax.jit(jax.vmap(jax.grad(fast_loss)))

顺序影响性能但不影响正确性。

JAX vs NumPy - 关键区别

1. 不可变性

NumPy(可变):

import numpy as np
x = np.array([1, 2, 3])
x[0] = 10  # 原地修改x
print(x)  # [10 2 3]

JAX(不可变):

import jax.numpy as jnp
x = jnp.array([1, 2, 3])
# x[0] = 10  # 错误:JAX数组不可变

# 使用函数式更新代替
x = x.at[0].set(10)  # 返回新数组
print(x)  # [10 2 3]

函数式更新:

# 设置值
x = x.at[0].set(10)

# 增加值
x = x.at[0].add(5)

# 乘以值
x = x.at[0].mul(2)

# 最小/最大
x = x.at[0].min(5)
x = x.at[0].max(5)

# 多索引
x = x.at[0, 1].set(10)
x = x.at[[0, 2]].set(10)
x = x.at[0:3].set(10)

2. 随机数生成

NumPy(全局状态):

import numpy as np
np.random.seed(42)
x = np.random.normal(size=3)
y = np.random.normal(size=3)  # 与x不同

JAX(显式密钥):

import jax
key = jax.random.PRNGKey(42)

# 分割密钥以独立随机
key, subkey = jax.random.split(key)
x = jax.random.normal(subkey, shape=(3,))

key, subkey = jax.random.split(key)
y = jax.random.normal(subkey, shape=(3,))

# 并行随机数
keys = jax.random.split(key, num=10)
samples = jax.vmap(lambda k: jax.random.normal(k, shape=(3,)))(keys)

密钥管理模式:

# 创建初始密钥
key = jax.random.PRNGKey(0)

# 使用前始终分割
key, subkey1, subkey2 = jax.random.split(key, 3)

# 使用子密钥进行随机操作
x = jax.random.normal(subkey1, shape=(10,))
y = jax.random.uniform(subkey2, shape=(10,))

# 绝不重用密钥!

3. 64位精度

JAX默认32位以优化性能。

import jax.numpy as jnp

x = jnp.array([1.0])
print(x.dtype)  # float32(默认)

# 全局启用64位
from jax import config
config.update("jax_enable_x64", True)

x = jnp.array([1.0])
print(x.dtype)  # float64

# 或使用环境变量
# JAX_ENABLE_X64=1 python script.py

4. 越界索引

NumPy(引发错误):

import numpy as np
x = np.array([1, 2, 3])
# x[10]  # IndexError

JAX(静默截断):

import jax.numpy as jnp
x = jnp.array([1, 2, 3])
print(x[10])  # 3(返回最后一个元素,无错误!)

# 更新也静默忽略
x = x.at[10].set(99)  # 无错误,无效果

⚠️ 警告: 这是未定义行为 - 避免依赖它!

5. 非数组输入

NumPy(接受列表):

import numpy as np
result = np.sum([1, 2, 3])  # 工作

JAX(需要数组):

import jax.numpy as jnp
# result = jnp.sum([1, 2, 3])  # 错误

# 显式转换
result = jnp.sum(jnp.array([1, 2, 3]))  # 工作

原因: 防止跟踪期间性能下降。

尖锐点与陷阱

1. 需纯函数

❌ 错误:副作用

counter = 0

@jax.jit
def impure_fn(x):
    global counter
    counter += 1  # 副作用 - 不可靠
    return x**2

# 第一次调用:counter=1(跟踪)
result = impure_fn(2.0)
# 后续调用:counter仍为1(使用缓存跟踪)
result = impure_fn(3.0)

✅ 正确:纯函数

@jax.jit
def pure_fn(x, counter):
    return x**2, counter + 1

result, counter = pure_fn(2.0, 0)
result, counter = pure_fn(3.0, counter)

2. 控制流限制

❌ 错误:值依赖控制流

@jax.jit
def conditional(x):
    if x > 0:  # 错误:x是跟踪器,非具体值
        return x**2
    else:
        return x**3

✅ 正确:使用jax.lax.cond

@jax.jit
def conditional(x):
    return jax.lax.cond(
        x > 0,
        lambda x: x**2,  # 真分支
        lambda x: x**3   # 假分支
    )

循环:

# ❌ 错误:JIT中的Python for循环(编译时展开)
@jax.jit
def loop_bad(x, n):
    for i in range(n):  # n必须为静态
        x = x + 1
    return x

# ✅ 正确:使用jax.lax.fori_loop
@jax.jit
def loop_good(x, n):
    def body(i, val):
        return val + 1
    return jax.lax.fori_loop(0, n, body, x)

while循环:

@jax.jit
def while_loop(x):
    def cond_fun(val):
        return val < 10

    def body_fun(val):
        return val + 1

    return jax.lax.while_loop(cond_fun, body_fun, x)

3. 动态形状

❌ 错误:形状依赖运行时值

@jax.jit
def dynamic_shape(x, mask):
    return x[mask]  # 错误:编译时输出形状未知

✅ 正确:使用jnp.where进行掩码

@jax.jit
def static_shape(x, mask):
    return jnp.where(mask, x, 0)  # 形状保持不变

4. 原地更新语义

❌ 错误:依赖共享引用

x = jnp.array([1, 2, 3])
y = x
x = x.at[0].set(10)
print(y[0])  # 仍为1(y未改变)

注意: .at返回新数组;不修改原数组。

5. JIT中的打印调试

❌ 错误:JIT上下文中的print()

@jax.jit
def debug(x):
    print(f"x = {x}")  # 仅打印一次(跟踪期间)
    return x**2

✅ 正确:使用jax.debug.print

@jax.jit
def debug(x):
    jax.debug.print("x = {}", x)  # 每次调用都打印
    return x**2

6. 通过离散操作的梯度

❌ 错误:通过argmax的梯度

def loss(x):
    idx = jnp.argmax(x)  # 离散 - 无梯度
    return x[idx]

# grad_loss = jax.grad(loss)  # 梯度处处为零

✅ 正确:使用可微分近似

def loss(x):
    # 软argmax(可微分)
    weights = jax.nn.softmax(x * temperature)
    return jnp.sum(x * weights)

grad_loss = jax.grad(loss)

自动微分

基本梯度

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x**2)

grad_f = jax.grad(f)

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))  # [2. 4. 6.]

值与梯度

# 高效计算值和梯度
value_and_grad_f = jax.value_and_grad(f)

value, gradient = value_and_grad_f(x)
print(f"值: {value}, 梯度: {gradient}")

多参数

def f(x, y):
    return jnp.sum(x**2 + y**3)

# 梯度相对于第一个参数(默认)
grad_f = jax.grad(f)
print(grad_f(x, y))

# 梯度相对于第二个参数
grad_f_wrt_y = jax.grad(f, argnums=1)
print(grad_f_wrt_y(x, y))

# 相对于两个参数的梯度
grad_f_both = jax.grad(f, argnums=(0, 1))
grad_x, grad_y = grad_f_both(x, y)

辅助数据

def loss_with_aux(x):
    loss = jnp.sum(x**2)
    aux_data = {'norm': jnp.linalg.norm(x), 'mean': jnp.mean(x)}
    return loss, aux_data

# 告知JAX辅助输出
grad_fn = jax.grad(loss_with_aux, has_aux=True)
gradient, aux = grad_fn(x)

print(f"梯度: {gradient}")
print(f"辅助: {aux}")

高阶导数

# 二阶导数(Hessian对角线)
def f(x):
    return jnp.sum(x**3)

grad_f = jax.grad(f)
hess_diag_f = jax.grad(lambda x: jnp.sum(grad_f(x) * x))

# 完整Hessian
hessian_f = jax.hessian(f)

x = jnp.array([1.0, 2.0, 3.0])
print(hessian_f(x))

Jacobian

def vector_fn(x):
    """向量到向量函数"""
    return jnp.array([x[0]**2, x[1]**3, x[0]*x[1]])

# 前向模式(输入少输出多时高效)
jacfwd = jax.jacfwd(vector_fn)

# 反向模式(输入多输出少时高效)
jacrev = jax.jacrev(vector_fn)

x = jnp.array([2.0, 3.0])
print(jacfwd(x))
print(jacrev(x))  # 相同结果

自定义梯度

@jax.custom_gradient
def f(x):
    # 前向传递
    result = jnp.exp(x)

    # 定义自定义梯度
    def grad_fn(g):
        # g是下游梯度
        # 返回相对于输入的梯度
        return g * 2 * result  # 自定义:2倍正常梯度

    return result, grad_fn

# 像正常一样使用
grad_f = jax.grad(f)

JIT编译

基本用法

import jax
import jax.numpy as jnp

def slow_fn(x):
    return jnp.sum(x**2 + 3*x + 1)

fast_fn = jax.jit(slow_fn)

# 或装饰器
@jax.jit
def fast_fn2(x):
    return jnp.sum(x**2 + 3*x + 1)

静态参数

@jax.jit
def fn(x, n):
    for i in range(n):  # n必须为静态
        x = x + 1
    return x

# 错误:n每次调用变化,触发重新编译
result = fn(x, 5)
result = fn(x, 10)  # 重新编译!

# 解决方案:标记为静态
@jax.jit(static_argnums=(1,))
def fn_static(x, n):
    for i in range(n):
        x = x + 1
    return x

# 现在正确工作
result = fn_static(x, 5)
result = fn_static(x, 10)  # n=10的单独编译

避免重新编译

# 错误:不同形状触发重新编译
@jax.jit
def process(x):
    return jnp.sum(x**2)

x1 = jnp.ones(10)    # 编译形状(10,)
x2 = jnp.ones(20)    # 重新编译形状(20,)
x3 = jnp.ones(10)    # 使用缓存版本

# 正确:一致形状
batch_size = 32
x1 = jnp.ones((batch_size, 10))
x2 = jnp.ones((batch_size, 10))  # 相同形状,缓存

向量化(vmap)

基本批处理

def process_single(x):
    """处理单个示例:标量输入"""
    return x**2 + 3*x

# 手动向量化
def process_batch_manual(xs):
    return jnp.array([process_single(x) for x in xs])

# 自动向量化
process_batch = jax.vmap(process_single)

batch = jnp.array([1.0, 2.0, 3.0, 4.0])
print(process_batch(batch))  # 更快更干净

批处理矩阵操作

def matrix_vector_product(matrix, vector):
    """单个矩阵-向量乘积"""
    return matrix @ vector

# 在向量上批处理
batch_mvp = jax.vmap(matrix_vector_product, in_axes=(None, 0))

A = jnp.ones((3, 3))
vectors = jnp.ones((10, 3))  # 10个向量

results = batch_mvp(A, vectors)  # (10, 3)

# 在两者上批处理
batch_both = jax.vmap(matrix_vector_product, in_axes=(0, 0))

matrices = jnp.ones((10, 3, 3))
results = batch_both(matrices, vectors)

嵌套vmap

# 在两个维度上批处理
def fn(x, y):
    return x * y

# 在第一个参数上批处理,然后在第二个上
fn_batch = jax.vmap(jax.vmap(fn, in_axes=(None, 0)), in_axes=(0, None))

x = jnp.array([1, 2, 3])       # (3,)
y = jnp.array([10, 20])        # (2,)
result = fn_batch(x, y)        # (3, 2)

# 等同于:x[:, None] * y[None, :]

PyTrees - 嵌套结构

JAX适用于嵌套Python容器(pytrees):

import jax
import jax.numpy as jnp

# 数组字典
params = {
    'w1': jnp.ones((10, 5)),
    'b1': jnp.zeros(5),
    'w2': jnp.ones((5, 1)),
    'b2': jnp.zeros(1)
}

# 梯度适用于整个pytree
def loss(params, x):
    h = x @ params['w1'] + params['b1']
    h = jax.nn.relu(h)
    out = h @ params['w2'] + params['b2']
    return jnp.mean(out**2)

grad_fn = jax.grad(loss)
x = jnp.ones((32, 10))

# 返回相同结构的梯度
grads = grad_fn(params, x)
print(grads.keys())  # dict_keys(['w1', 'b1', 'w2', 'b2'])

PyTree操作

# 树映射 - 对所有叶子应用函数
scaled_params = jax.tree_map(lambda x: x * 0.9, params)

# 树归约
total_params = jax.tree_reduce(
    lambda total, x: total + x.size,
    params,
    initializer=0
)

# 扁平化和反扁平化
flat, treedef = jax.tree_flatten(params)
reconstructed = jax.tree_unflatten(treedef, flat)

常见模式

优化循环

import jax
import jax.numpy as jnp

# 参数
params = jnp.array([1.0, 2.0, 3.0])

# 损失函数
def loss(params, x, y):
    pred = jnp.dot(params, x)
    return jnp.mean((pred - y)**2)

# 梯度函数
grad_fn = jax.jit(jax.grad(loss))

# 训练数据
x_train = jnp.ones((100, 3))
y_train = jnp.ones(100)

# 优化循环
learning_rate = 0.01
for step in range(1000):
    grads = grad_fn(params, x_train, y_train)
    params = params - learning_rate * grads

    if step % 100 == 0:
        l = loss(params, x_train, y_train)
        print(f"步骤 {step}, 损失: {l:.4f}")

小批量训练

def train_step(params, batch):
    """单个训练步骤在一个批次上"""
    x, y = batch

    def batch_loss(params):
        pred = jnp.dot(x, params)
        return jnp.mean((pred - y)**2)

    loss_value, grads = jax.value_and_grad(batch_loss)(params)
    params = params - 0.01 * grads
    return params, loss_value

# JIT编译训练步骤
train_step = jax.jit(train_step)

# 训练循环
for epoch in range(10):
    for batch in data_loader:
        params, loss = train_step(params, batch)

Scan用于高效循环

def cumulative_sum(xs):
    """使用scan的累积和"""
    def step(carry, x):
        new_carry = carry + x
        output = new_carry
        return new_carry, output

    final_carry, outputs = jax.lax.scan(step, 0, xs)
    return outputs

xs = jnp.array([1, 2, 3, 4, 5])
print(cumulative_sum(xs))  # [1 3 6 10 15]

带Scan的RNN

def rnn_step(carry, x):
    """单个RNN步骤"""
    h = carry
    h_new = jnp.tanh(jnp.dot(W_h, h) + jnp.dot(W_x, x))
    return h_new, h_new

def rnn(xs, h0):
    """在序列上运行RNN"""
    final_h, all_h = jax.lax.scan(rnn_step, h0, xs)
    return all_h

# 处理序列
W_h = jnp.ones((5, 5))
W_x = jnp.ones((5, 3))
xs = jnp.ones((10, 3))  # 序列长度10
h0 = jnp.zeros(5)

outputs = rnn(xs, h0)  # (10, 5)

性能最佳实践

1. JIT关键路径

# 将昂贵计算包装在jit中
@jax.jit
def expensive_fn(x):
    for _ in range(100):
        x = jnp.dot(x, x.T)
    return x

# 避免JIT琐碎操作
def trivial(x):
    return x + 1  # JIT开销不值得

2. 使用vmap替代循环

# 慢:Python循环
def slow_batch(xs):
    return jnp.array([process(x) for x in xs])

# 快:vmap
fast_batch = jax.vmap(process)

3. 最小化主机-设备传输

# 错误:在循环中传输回主机
for i in range(1000):
    x = compute_on_gpu(x)
    print(float(x))  # 每次迭代传输到CPU!

# 正确:在结束时传输一次
for i in range(1000):
    x = compute_on_gpu(x)
print(float(x))  # 单次传输

4. 使用合适精度

# 使用float32除非需要float64
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

# 仅在必要时启用float64
from jax import config
config.update("jax_enable_x64", True)

5. 尽可能预分配

# 错误:增长数组
result = jnp.array([])
for i in range(1000):
    result = jnp.append(result, compute(i))

# 正确:预分配
result = jnp.zeros(1000)
for i in range(1000):
    result = result.at[i].set(compute(i))

# 更好:使用scan或vmap
def compute_all(i):
    return compute(i)

result = jax.vmap(compute_all)(jnp.arange(1000))

调试

检查数组值

# 在JIT上下文中使用jax.debug.print
@jax.jit
def debug_fn(x):
    jax.debug.print("x = {}", x)
    jax.debug.print("x shape = {}, dtype = {}", x.shape, x.dtype)
    return x**2

梯度检查

from jax.test_util import check_grads

def f(x):
    return jnp.sum(x**3)

x = jnp.array([1.0, 2.0, 3.0])

# 数值验证梯度
check_grads(f, (x,), order=2)  # 检查一阶和二阶导数

检查编译代码

# 查看jaxpr(中间表示)
def f(x):
    return x**2 + 3*x

jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

# 查看HLO(低级编译代码)
compiled = jax.jit(f).lower(1.0).compile()
print(compiled.as_text())

常见陷阱 - 快速参考

陷阱 NumPy JAX 解决方案
原地更新 x[0] = 1 ❌ 错误 x = x.at[0].set(1)
随机状态 np.random.seed() ❌ 不可靠 key = jax.random.PRNGKey()
列表输入 np.sum([1,2,3]) ❌ 错误 jnp.sum(jnp.array([1,2,3]))
越界 IndexError ⚠️ 静默截断 避免,验证索引
值依赖if 工作 ❌ 在JIT中 jax.lax.cond()
动态形状 工作 ❌ 在JIT中 保持形状静态
默认精度 float64 float32 设置jax_enable_x64
JIT中打印 工作 仅一次 jax.debug.print()

安装

# 仅CPU
pip install jax

# GPU(CUDA 12)
pip install jax[cuda12]

# TPU
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

生态系统库

JAX在其变换原语上建立了丰富的生态系统:

神经网络:

  • Flax - 官方神经网络库
  • Haiku - DeepMind的神经网络库
  • Equinox - 优雅的PyTorch风格库

优化:

  • Optax - 梯度处理和优化
  • JAXopt - 非线性优化

科学计算:

  • JAX-MD - 分子动力学
  • Diffrax - 微分方程求解器
  • BlackJAX - MCMC采样

实用工具:

  • jaxtyping - 数组类型注解
  • chex - 测试实用工具

额外资源

相关技能

  • python-optimization - 使用scipy、pyomo进行数值优化
  • python-ase - 原子模拟环境(可使用JAX进行力计算)
  • pycse - 科学计算实用工具
  • python-best-practices - JAX项目的代码质量