PyTorch Training → ONNX Export → Inference Acceleration


ONNX end-to-end workflow: PyTorch export → ONNX Runtime inference

Goal

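Demonstrate the standard workflow: load an ImageNet-pretrained ResNet-18 in PyTorch → export it to ONNX → run it with ONNX Runtime → check that both produce the same outputs and compare their speed.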

Full code

```python
import time

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torchvision.models as models

# ─── 1. PyTorch model ───
# ImageNet-pretrained ResNet-18. The `weights=` API (torchvision >= 0.13)
# replaces the deprecated `pretrained=True` argument.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.eval()  # inference behaviour: BatchNorm uses its running statistics

dummy_input = torch.randn(1, 3, 224, 224)  # N x C x H x W

# One PyTorch forward pass as the reference output
with torch.no_grad():
    torch_out = model(dummy_input)

# ─── 2. Export to ONNX ───
onnx_path = "resnet18.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,          # store the trained weights in the .onnx file
    opset_version=17,            # ONNX opset
    input_names=["input"],       # input tensor name
    output_names=["output"],     # output tensor name
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},  # dynamic batch dimension
)

# Check that the exported model is well-formed
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print(f"✅ ONNX model exported: {onnx_path} ({onnx_model.ByteSize() / 1e6:.1f} MB)")

# ─── 3. ONNX Runtime inference ───
session = ort.InferenceSession(
    onnx_path,
    # For GPU: ["CUDAExecutionProvider", "CPUExecutionProvider"] (needs the onnxruntime-gpu package)
    providers=["CPUExecutionProvider"],
)

# Input feed: {input name: NumPy array}
ort_inputs = {session.get_inputs()[0].name: dummy_input.numpy()}

# One warm-up run, then time 100 runs
_ = session.run(None, ort_inputs)
start = time.perf_counter()
for _ in range(100):
    session.run(None, ort_inputs)
ort_time = time.perf_counter() - start

# ─── 4. Check numerical consistency ───
ort_out = session.run(None, ort_inputs)[0]
diff = np.abs(torch_out.numpy() - ort_out).max()
print(f"PyTorch vs ONNX max abs error: {diff:.2e}")

# ─── 5. Performance comparison ───
with torch.no_grad():
    _ = model(dummy_input)  # warm-up
    start = time.perf_counter()
    for _ in range(100):
        model(dummy_input)
    torch_time = time.perf_counter() - start

print("\nPerformance comparison (100 inferences):")
print(f"  PyTorch:      {torch_time:.3f}s")
print(f"  ONNX Runtime: {ort_time:.3f}s")
print(f"  Speedup:      {torch_time / ort_time:.2f}x")
```
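Step 4 only prints the largest absolute difference. A minimal sketch like the one below could turn that into a pass/fail check. It compares the two outputs against explicit tolerances and also checks that both backends pick the same top-1 class. The helper name `compare_outputs` is hypothetical. The defaults `rtol=1e-3` and `atol=1e-5` are illustrative assumptions, not values from the original script, so adjust them for your model.

```python
import numpy as np


def compare_outputs(ref: np.ndarray, test: np.ndarray,
                    rtol: float = 1e-3, atol: float = 1e-5) -> dict:
    """Compare a reference output (e.g. PyTorch logits) with a test output (e.g. ONNX Runtime).

    Hypothetical helper; the tolerance defaults are illustrative assumptions.
    Expects float arrays of identical shape with classes on the last axis.
    """
    if ref.shape != test.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {test.shape}")
    ref64 = ref.astype(np.float64)
    test64 = test.astype(np.float64)
    abs_err = np.abs(ref64 - test64)
    # Relative error; the small floor avoids division by zero for near-zero references
    rel_err = abs_err / np.maximum(np.abs(ref64), 1e-12)
    return {
        # Same criterion as np.allclose: |test - ref| <= atol + rtol * |ref|
        "passed": bool(np.allclose(test64, ref64, rtol=rtol, atol=atol)),
        "max_abs": float(abs_err.max()),
        "max_rel": float(rel_err.max()),
        # For classifiers: do both backends predict the same class for every sample?
        "top1_match": bool((ref64.argmax(axis=-1) == test64.argmax(axis=-1)).all()),
    }


if __name__ == "__main__":
    # Synthetic logits stand in for torch_out.numpy() and ort_out from the script above
    rng = np.random.default_rng(0)
    ref = rng.standard_normal((2, 1000)).astype(np.float32)
    print(compare_outputs(ref, ref + np.float32(1e-6)))
```

In the script itself this would be called as `compare_outputs(torch_out.numpy(), ort_out)`. Checking top-1 agreement alongside the tolerance is useful because a tiny numerical drift matters less for a classifier than a changed prediction.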

Steps to run

Save the code above as `onnx_export_infer.py`, then run:

```bash
pip install torch torchvision onnx onnxruntime numpy
python onnx_export_infer.py
```
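The script pins `CPUExecutionProvider`. To run on an NVIDIA GPU, install the `onnxruntime-gpu` package instead of `onnxruntime`. Calling `ort.get_available_providers()` reports which execution providers the installed build supports. The sketch below assumes you want CUDA when it is available and CPU otherwise, and filters a preference list against that report. `pick_providers` and its preference order are hypothetical and are not part of ONNX Runtime.

```python
from typing import List, Sequence

# Hypothetical preference order: try CUDA first, fall back to CPU
PREFERRED = ("CUDAExecutionProvider", "CPUExecutionProvider")


def pick_providers(available: Sequence[str],
                   preferred: Sequence[str] = PREFERRED) -> List[str]:
    """Keep the preferred providers that the installed onnxruntime build offers, in order.

    `available` would normally come from onnxruntime.get_available_providers().
    CPUExecutionProvider is always placed last as a fallback.
    """
    chosen = [p for p in preferred if p in available and p != "CPUExecutionProvider"]
    chosen.append("CPUExecutionProvider")
    return chosen


if __name__ == "__main__":
    # Hand-written provider lists: a CPU-only build vs. a GPU build
    print(pick_providers(["CPUExecutionProvider"]))
    print(pick_providers(["CUDAExecutionProvider", "CPUExecutionProvider"]))
```

In the script this would become `ort.InferenceSession(onnx_path, providers=pick_providers(ort.get_available_providers()))`. Afterwards, `session.get_providers()` shows which providers the session actually ended up with.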

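Expected output

Exact timings and error values depend on hardware, thread settings, and library versions.

```text
✅ ONNX model exported: resnet18.onnx (44.7 MB)
PyTorch vs ONNX max abs error: 1.19e-07

Performance comparison (100 inferences):
  PyTorch:      2.834s
  ONNX Runtime: 1.521s
  Speedup:      1.86x
```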
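The script does one warm-up and then times 100 back-to-back runs as a single total. As a result, one slow outlier (for example, a thread pool spinning up) can skew the speedup. A slightly more robust sketch runs several warm-up calls, times each call with `time.perf_counter()`, and reports the median and 90th percentile. The helper `benchmark` and its defaults (`warmup=10`, `iters=100`) are illustrative assumptions. Both backends should be measured the same way, with the same input and comparable thread counts (`torch.set_num_threads` for PyTorch, `SessionOptions.intra_op_num_threads` for ONNX Runtime).

```python
import math
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 10, iters: int = 100) -> Dict[str, float]:
    """Time fn() per call and summarize in milliseconds (hypothetical helper)."""
    if iters < 1:
        raise ValueError("iters must be >= 1")
    for _ in range(warmup):  # warm-up: caches, lazy initialization, thread pools
        fn()
    samples_ms = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - t0) * 1e3)
    samples_ms.sort()
    # Nearest-rank 90th percentile: rank ceil(0.9 * N), converted to a 0-based index
    p90_index = math.ceil(0.9 * iters) - 1
    return {
        "median_ms": statistics.median(samples_ms),
        "p90_ms": samples_ms[p90_index],
        "mean_ms": statistics.fmean(samples_ms),
    }


if __name__ == "__main__":
    # Stand-in workload. With the script above you would pass e.g.
    #   lambda: session.run(None, ort_inputs)
    # and, inside `with torch.no_grad():`, lambda: model(dummy_input)
    stats = benchmark(lambda: sum(i * i for i in range(10_000)))
    print({k: round(v, 3) for k, v in stats.items()})
```

The speedup can then be computed as `torch_stats["median_ms"] / ort_stats["median_ms"]`. The median is less sensitive to occasional stalls than the total over all runs.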

Info

Path
/tech-stacks/onnx/examples/PyTorch 训练 → ONNX 导出 → 推理加速.md
Last updated
2026/5/31