文档
JAX 入门教程:纯函数式思维与神经网络
1. JAX 的哲学
JAX 不是另一个 TensorFlow 或 PyTorch。它的核心是函数变换:把普通的 Python/NumPy 函数,变换为可微分、可编译、可并行的版本。
你的函数 f ──► jax.grad(f) ──► 自动求导
──► jax.jit(f) ──► XLA 编译加速
──► jax.vmap(f) ──► 批量向量化
这三大变换可以任意组合:jax.jit(jax.grad(jax.vmap(f))) ✅
2. RNG:显式随机状态
JAX 没有全局随机状态,必须显式传递 PRNGKey:
key = jax.random.PRNGKey(42) # 种子
key, subkey = jax.random.split(key)
weights = jax.random.normal(subkey, (784, 256))
这是为了纯函数性——同样的 key 产生同样的结果,方便复现与并行。
3. 不可变性陷阱
x = jnp.array([1, 2, 3])
x[0] = 99 # ❌ 报错!JAX 数组不可变
x = x.at[0].set(99) # ✅ 返回新数组
这起初让人沮丧,但确保了函数无副作用,编译器可做激进优化。
4. Flax:JAX 生态的神经网络库(推荐)
import flax.linen as nn
import optax
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(10)(x)
return x
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
def loss_fn(params, x, y):
logits = model.apply(params, x)
return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, y))
grads = jax.grad(loss_fn)(params, x_batch, y_batch)
5. JAX vs PyTorch vs TensorFlow 选型
| 维度 | JAX | PyTorch | TensorFlow |
|---|---|---|---|
| 编程范式 | 纯函数式 | 命令式 OOP | 符号图+Eager |
| 调试 | 一般(jit 内难 print) | 优秀(Python 原生) | 一般 |
| TPU 支持 | ⭐⭐⭐ 最佳 | ⭐ | ⭐⭐ |
| 社区热度 | 研究界 | 全领域 | 产业界 |
| 学习曲线 | 陡峭 | 平缓 | 中等 |
6. 什么时候选 JAX?
- 你在搞研究(DeepMind 全在用)
- 需要自定义复杂梯度(双层优化、元学习)
- TPU 是你的主战场
- 你喜欢函数式编程
思考题
- 为什么 JAX 要坚持"不可变数组"?好处是什么?
jax.jit(jax.grad(f))和jax.grad(jax.jit(f))行为有何不同?- 用 JAX 实现一个简单的 linear regression 训练循环。