文档
JAX 自动微分 + JIT Hello World
目标
体验 JAX 的三大核心变换:grad(自动求导)、jit(编译加速)、vmap(自动向量化)。
完整代码
import jax
import jax.numpy as jnp
import time
# ───────── 1. 基础:jax.numpy 与 NumPy 几乎一样 ─────────
x = jnp.array([1.0, 2.0, 3.0, 4.0])
print("x:", x)
print("sin(x):", jnp.sin(x))
# ───────── 2. grad:自动求导 ─────────
def f(x):
"""f(x) = x³ + 2x² + x"""
return x**3 + 2 * x**2 + x
df = jax.grad(f) # 一阶导数
d2f = jax.grad(df) # 二阶导数
print(f"\nf(2) = {f(2.0)}")
print(f"f'(2) = {df(2.0)} (期望: 3*4 + 4*2 + 1 = 21)")
print(f"f''(2) = {d2f(2.0)} (期望: 6*2 + 4 = 16)")
# ───────── 3. jit:即时编译加速 ─────────
@jax.jit
def matmul(a, b):
return jnp.dot(a, b)
a = jnp.ones((1000, 1000))
b = jnp.ones((1000, 1000))
# 第一次调用会编译(慢),第二次起飞
_ = matmul(a, b) # 预热
start = time.time()
for _ in range(100):
matmul(a, b)
print(f"\njit 加速 100 次矩阵乘法: {time.time() - start:.3f}s")
# ───────── 4. vmap:自动向量化 ─────────
def apply_kernel(x):
"""对单个元素操作"""
return jnp.exp(x) / (1 + jnp.exp(x)) # sigmoid
batch = jnp.linspace(-5, 5, 10)
sigmoid_batch = jax.vmap(apply_kernel)(batch)
print("\nvmap sigmoid([-5..5]):", sigmoid_batch)
运行步骤
pip install jax jaxlib
python jax_demo.py
预期输出
x: [1. 2. 3. 4.]
sin(x): [0.8415 0.9093 0.1411 -0.7568]
f(2) = 14.0
f'(2) = 21.0 (期望: 3*4 + 4*2 + 1 = 21)
f''(2) = 16.0 (期望: 6*2 + 4 = 16)
jit 加速 100 次矩阵乘法: 0.042s
vmap sigmoid([-5..5]): [0.0067 0.0179 ... 0.9933]