JAX 纯函数式思维与神经网络

知识库
知识库文档
/tech-stacks/jax/tutorial/JAX 纯函数式思维与神经网络.md

文档

JAX 入门教程:纯函数式思维与神经网络

1. JAX 的哲学

JAX 不是另一个 TensorFlow 或 PyTorch。它的核心是函数变换:把普通的 Python/NumPy 函数,变换为可微分、可编译、可并行的版本。

你的函数 f  ──► jax.grad(f)   ──► 自动求导
            ──► jax.jit(f)    ──► XLA 编译加速
            ──► jax.vmap(f)   ──► 批量向量化

这三大变换可以任意组合:jax.jit(jax.grad(jax.vmap(f)))

2. RNG:显式随机状态

JAX 没有全局随机状态,必须显式传递 PRNGKey:

key = jax.random.PRNGKey(42)     # 种子
key, subkey = jax.random.split(key)
weights = jax.random.normal(subkey, (784, 256))

这是为了纯函数性——同样的 key 产生同样的结果,方便复现与并行。

3. 不可变性陷阱

x = jnp.array([1, 2, 3])
x[0] = 99   # ❌ 报错!JAX 数组不可变

x = x.at[0].set(99)  # ✅ 返回新数组

这起初让人沮丧,但确保了函数无副作用,编译器可做激进优化。

4. Flax:JAX 生态的神经网络库(推荐)

import flax.linen as nn
import optax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(10)(x)
        return x

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))

def loss_fn(params, x, y):
    logits = model.apply(params, x)
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits, y))

grads = jax.grad(loss_fn)(params, x_batch, y_batch)

5. JAX vs PyTorch vs TensorFlow 选型

维度 JAX PyTorch TensorFlow
编程范式 纯函数式 命令式 OOP 符号图+Eager
调试 一般(jit 内难 print) 优秀(Python 原生) 一般
TPU 支持 ⭐⭐⭐ 最佳 ⭐⭐
社区热度 研究界 全领域 产业界
学习曲线 陡峭 平缓 中等

6. 什么时候选 JAX?

  • 你在搞研究(DeepMind 全在用)
  • 需要自定义复杂梯度(双层优化、元学习)
  • TPU 是你的主战场
  • 你喜欢函数式编程

思考题

  1. 为什么 JAX 要坚持"不可变数组"?好处是什么?
  2. jax.jit(jax.grad(f))jax.grad(jax.jit(f)) 行为有何不同?
  3. 用 JAX 实现一个简单的 linear regression 训练循环。

信息

路径
/tech-stacks/jax/tutorial/JAX 纯函数式思维与神经网络.md
更新时间
2026/5/31