文档
ONNX 完整流程:PyTorch 导出 → ONNX Runtime 推理
目标
演示标准工作流:用 PyTorch 加载预训练的 ResNet-18(不做训练)→ 导出为 ONNX → ONNX Runtime 推理 → 验证数值一致性 + 性能对比。
完整代码
import torch
import torchvision.models as models
import numpy as np
import onnx
import onnxruntime as ort
import time
# ─── 1. PyTorch model ───
# `pretrained=True` is deprecated since torchvision 0.13; the weights enum is the
# supported way to request the same ImageNet checkpoint (DEFAULT == IMAGENET1K_V1).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
# Switch layers such as BatchNorm to inference behavior before running/exporting.
model.eval()
# Fixed-shape random input: reused for tracing the export, the reference output
# and both benchmarks so every comparison sees the same tensor.
dummy_input = torch.randn(1, 3, 224, 224)
# Reference PyTorch output, compared against ONNX Runtime later.
with torch.no_grad():
    torch_out = model(dummy_input)
# ─── 2. Export to ONNX ───
onnx_path = "resnet18.onnx"
# Declare dim 0 of both graph endpoints as a symbolic "batch" dimension so the
# exported model is not locked to the batch size of 1 used for tracing.
batch_dynamic_axes = {
    "input": {0: "batch"},
    "output": {0: "batch"},
}
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    export_params=True,           # embed the weights in the .onnx file
    opset_version=17,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=batch_dynamic_axes,
)

# Reload the file from disk and run ONNX's model checker on it.
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
size_mb = onnx_model.ByteSize() / 1e6  # serialized size in decimal megabytes
print(f"✅ ONNX 模型导出成功: {onnx_path} ({size_mb:.1f} MB)")
# ─── 3. ONNX Runtime inference ───
session = ort.InferenceSession(
    onnx_path,
    # or ["CUDAExecutionProvider"] (requires a CUDA-enabled onnxruntime build)
    providers=["CPUExecutionProvider"],
)
# Feed by the graph's declared input name rather than hard-coding "input".
ort_inputs = {session.get_inputs()[0].name: dummy_input.numpy()}

# One warm-up run, excluded from timing (the first call may include one-time setup).
_ = session.run(None, ort_inputs)
# perf_counter is monotonic and high-resolution; time.time() is a wall clock that
# can jump (NTP/manual adjustments) or tick coarsely, skewing short benchmarks.
start = time.perf_counter()
for _ in range(100):
    session.run(None, ort_inputs)
ort_time = time.perf_counter() - start
# ─── 4. Numerical consistency check ───
ort_out = session.run(None, ort_inputs)[0]
diff = np.abs(torch_out.numpy() - ort_out).max()
print(f"PyTorch vs ONNX 最大误差: {diff:.2e}")
# Actually verify instead of only printing: raise if the two runtimes disagree.
# Tolerances follow the official PyTorch ONNX export tutorial and absorb the
# small float32 differences expected between independent kernel implementations.
np.testing.assert_allclose(torch_out.numpy(), ort_out, rtol=1e-03, atol=1e-05)
# ─── 5. Performance comparison ───
# Keep the warm-up AND the timed loop inside no_grad: the model's parameters
# require grad, so outside it PyTorch would record an autograd graph on every
# call and the comparison with ONNX Runtime would be unfair.
with torch.no_grad():
    _ = model(dummy_input)  # warm-up, excluded from timing
    start = time.perf_counter()  # monotonic, high-resolution benchmark clock
    for _ in range(100):
        model(dummy_input)
    torch_time = time.perf_counter() - start

print("\n性能对比 (100 次推理):")
print(f" PyTorch: {torch_time:.3f}s")
print(f" ONNX Runtime: {ort_time:.3f}s")
# Ratio > 1 means ONNX Runtime was faster than PyTorch for the same 100 calls.
print(f" 加速比: {torch_time/ort_time:.2f}x")
运行步骤
# `python -m pip` installs into the same interpreter that runs the script below.
python -m pip install torch torchvision onnx onnxruntime numpy
python onnx_export_infer.py
预期输出(示例;最大误差、耗时与加速比取决于硬件和库版本,实际数值会有所不同)
✅ ONNX 模型导出成功: resnet18.onnx (44.7 MB)
PyTorch vs ONNX 最大误差: 1.19e-07
性能对比 (100 次推理):
PyTorch: 2.834s
ONNX Runtime: 1.521s
加速比: 1.86x