JAX 入门教程:纯函数式思维与神经网络
1. JAX 的哲学
JAX 不是另一个 TensorFlow 或 PyTorch。它的核心是函数变换:把普通的 Python/NumPy 函数,变换为可微分、可编译、可并行的版本。
你的函数 f ──► jax.grad(f) ──► 自动求导
──► jax.jit(f) ──► XLA 编译加速
──► jax.vmap(f) ──► 批量向量化
这三大变换可以任意组合:jax.jit(jax.grad(jax.vmap(f))) ✅
2. RNG:显式随机状态
JAX 没有全局随机状态,必须显式传递 PRNGKey:
key = jax.random.PRNGKey(42) # 种子
key, subkey = jax.random.split(key)
weights = jax.random.normal(subkey, (784, 256))
这是为了纯函数性——同样的 key 产生同样的结果,方便复现与并行。
3. 不可变性陷阱
x = jnp.array([1, 2, 3])
x[0] = 99 # ❌ 报错!JAX 数组不可变
x = x.at[0].set(99) # ✅ 返回新数组
这起初让人沮丧,但确保了函数无副作用,编译器可做激进优化。
4. Flax:JAX 生态的神经网络库(推荐)
import flax.linen as nn
import optax
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(10)(x)
return x
model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
def loss_fn(params, x, y):
logits = model.apply(params, x)
return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, y))
grads = jax.grad(loss_fn)(params, x_batch, y_batch)
5. JAX vs PyTorch vs TensorFlow 选型
| 维度 |
JAX |
PyTorch |
TensorFlow |
| 编程范式 |
纯函数式 |
命令式 OOP |
符号图+Eager |
| 调试 |
一般(jit 内难 print) |
优秀(Python 原生) |
一般 |
| TPU 支持 |
⭐⭐⭐ 最佳 |
⭐ |
⭐⭐ |
| 社区热度 |
研究界 |
全领域 |
产业界 |
| 学习曲线 |
陡峭 |
平缓 |
中等 |
6. 什么时候选 JAX?
- 你在搞研究(DeepMind 全在用)
- 需要自定义复杂梯度(双层优化、元学习)
- TPU 是你的主战场
- 你喜欢函数式编程
思考题
- 为什么 JAX 要坚持"不可变数组"?好处是什么?
jax.jit(jax.grad(f)) 和 jax.grad(jax.jit(f)) 行为有何不同?
- 用 JAX 实现一个简单的 linear regression 训练循环。