JAX

技术栈
AI 框架
deep-learningnumerical-computingautodiffxlatpugoogle

概览

JAX

JAX 是 Google Research 开源的高性能数值计算 + 自动微分框架。它的核心哲学是"NumPy + 函数式变换":jit(编译加速)、grad(自动求导)、vmap(自动向量化)。

核心价值:

  • JIT 编译@jax.jit 将 Python 函数编译为 XLA 高效代码,TPU/GPU 极致加速
  • 函数式微分jax.grad 求任意阶导数,jax.vmap 消除手写 batch 循环
  • NumPy 兼容jax.numpy 与 NumPy API 几乎一致,无缝迁移
  • SOTA 生态:Haiku / Flax / Trax 三大神经网络库,DeepMind / Google Brain 广泛使用

适用场景: 研究型项目、需要自定义梯度、TPU 训练、强化学习、扩散模型。

安装

环境准备

  • Python:>= 3.9(推荐 3.10 / 3.11)
  • 硬件:TPU 最佳,NVIDIA GPU(CUDA 12+)/ Apple Silicon(Metal)也支持
  • pip:>= 23.0

安装命令

CPU 版

pip install jax jaxlib

NVIDIA GPU 版

# CUDA 12
pip install jax[cuda12] jaxlib

# CUDA 11
pip install jax[cuda11] jaxlib

Apple Silicon (M1/M2/M3)

pip install jax jaxlib
# 自动启用 Metal 加速

验证安装

import jax
import jax.numpy as jnp

print("JAX 版本:", jax.__version__)
print("可用设备:", jax.devices())

# 简单测试
x = jnp.array([1.0, 2.0, 3.0])
print("jit 加法:", jax.jit(lambda a: a + 1)(x))

常见安装问题

Q1: jaxlib 安装失败(找不到匹配 wheel)

JAX 官方不提供所有 Python 版本的预编译 wheel。降级 Python 到 3.10/3.11,或从源码编译 jaxlib。

Q2: GPU 不可见

检查 CUDA 版本:nvcc --version。JAX GPU 需 CUDA 12+ 且 cuDNN 9+。查看 https://jax.readthedocs.io/en/latest/installation.html

Q3: RuntimeError: Unknown backend: 'gpu'

jaxlib 安装的是 CPU 版。重新安装 pip install jax[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

示例

JAX 自动微分 + JIT Hello World

目标

体验 JAX 的三大核心变换:grad(自动求导)、jit(编译加速)、vmap(自动向量化)。

完整代码

import jax
import jax.numpy as jnp
import time

# ───────── 1. 基础:jax.numpy 与 NumPy 几乎一样 ─────────
x = jnp.array([1.0, 2.0, 3.0, 4.0])
print("x:", x)
print("sin(x):", jnp.sin(x))

# ───────── 2. grad:自动求导 ─────────
def f(x):
    """f(x) = x³ + 2x² + x"""
    return x**3 + 2 * x**2 + x

df = jax.grad(f)          # 一阶导数
d2f = jax.grad(df)        # 二阶导数
print(f"\nf(2) = {f(2.0)}")
print(f"f'(2) = {df(2.0)} (期望: 3*4 + 4*2 + 1 = 21)")
print(f"f''(2) = {d2f(2.0)} (期望: 6*2 + 4 = 16)")

# ───────── 3. jit:即时编译加速 ─────────
@jax.jit
def matmul(a, b):
    return jnp.dot(a, b)

a = jnp.ones((1000, 1000))
b = jnp.ones((1000, 1000))

# 第一次调用会编译(慢),第二次起飞
_ = matmul(a, b)  # 预热
start = time.time()
for _ in range(100):
    matmul(a, b)
print(f"\njit 加速 100 次矩阵乘法: {time.time() - start:.3f}s")

# ───────── 4. vmap:自动向量化 ─────────
def apply_kernel(x):
    """对单个元素操作"""
    return jnp.exp(x) / (1 + jnp.exp(x))  # sigmoid

batch = jnp.linspace(-5, 5, 10)
sigmoid_batch = jax.vmap(apply_kernel)(batch)
print("\nvmap sigmoid([-5..5]):", sigmoid_batch)

运行步骤

pip install jax jaxlib
python jax_demo.py

预期输出

x: [1. 2. 3. 4.]
sin(x): [0.8415 0.9093 0.1411 -0.7568]

f(2) = 14.0
f'(2) = 21.0 (期望: 3*4 + 4*2 + 1 = 21)
f''(2) = 16.0 (期望: 6*2 + 4 = 16)

jit 加速 100 次矩阵乘法: 0.042s

vmap sigmoid([-5..5]): [0.0067 0.0179 ... 0.9933]

教程

JAX 入门教程:纯函数式思维与神经网络

1. JAX 的哲学

JAX 不是另一个 TensorFlow 或 PyTorch。它的核心是函数变换:把普通的 Python/NumPy 函数,变换为可微分、可编译、可并行的版本。

你的函数 f  ──► jax.grad(f)   ──► 自动求导
            ──► jax.jit(f)    ──► XLA 编译加速
            ──► jax.vmap(f)   ──► 批量向量化

这三大变换可以任意组合:jax.jit(jax.grad(jax.vmap(f)))

2. RNG:显式随机状态

JAX 没有全局随机状态,必须显式传递 PRNGKey:

key = jax.random.PRNGKey(42)     # 种子
key, subkey = jax.random.split(key)
weights = jax.random.normal(subkey, (784, 256))

这是为了纯函数性——同样的 key 产生同样的结果,方便复现与并行。

3. 不可变性陷阱

x = jnp.array([1, 2, 3])
x[0] = 99   # ❌ 报错!JAX 数组不可变

x = x.at[0].set(99)  # ✅ 返回新数组

这起初让人沮丧,但确保了函数无副作用,编译器可做激进优化。

4. Flax:JAX 生态的神经网络库(推荐)

import flax.linen as nn
import optax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return x

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))

def loss_fn(params, x, y):
    logits = model.apply(params, x)
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, y))

grads = jax.grad(loss_fn)(params, x_batch, y_batch)

5. JAX vs PyTorch vs TensorFlow 选型

维度 JAX PyTorch TensorFlow
编程范式 纯函数式 命令式 OOP 符号图+Eager
调试 一般(jit 内难 print) 优秀(Python 原生) 一般
TPU 支持 ⭐⭐⭐ 最佳 ⭐⭐
社区热度 研究界 全领域 产业界
学习曲线 陡峭 平缓 中等

6. 什么时候选 JAX?

  • 你在搞研究(DeepMind 全在用)
  • 需要自定义复杂梯度(双层优化、元学习)
  • TPU 是你的主战场
  • 你喜欢函数式编程

思考题

  1. 为什么 JAX 要坚持"不可变数组"?好处是什么?
  2. jax.jit(jax.grad(f))jax.grad(jax.jit(f)) 行为有何不同?
  3. 用 JAX 实现一个简单的 linear regression 训练循环。

参考资料

暂无参考文献