文档
FastAI 宠物品种分类:4 行代码 SOTA
目标
用 FastAI 的高级 API,仅 4 行核心代码训练一个 37 种宠物品种分类模型(Oxford-IIIT Pet Dataset)。
完整代码
from fastai.vision.all import *
# ─── 第 1 行:下载数据 ───
path = untar_data(URLs.PETS)
# ─── 第 2 行:声明式数据管道 ───
dls = ImageDataLoaders.from_name_re(
path, # 数据根目录
get_image_files(path/"images"), # 所有图片
pat=r'^(.*)_\d+\.jpg$', # 正则提取:品种_编号.jpg → 品种
item_tfms=Resize(460), # 先统一大小
batch_tfms=aug_transforms(size=224, min_scale=0.75), # 数据增强
bs=64, # batch size
)
# 查看数据
dls.show_batch(max_n=6, nrows=2)
# plt.show()
# ─── 第 3 行:创建学习器 ───
learn = vision_learner(
dls,
resnet34, # backbone:ResNet-34 预训练权重
metrics=[accuracy, error_rate],
# 自动:差分学习率、混合精度、LabelSmoothing
)
# ─── 第 4 行:训练! ───
learn.fine_tune(5) # 5 epoch:1 epoch 训练 head + 4 epoch 微调全部
# ─── 预测 ───
img = PILImage.create("my_cat.jpg")
breed, _, probs = learn.predict(img)
print(f"品种: {breed} (置信度: {probs.max():.2%})")
# ─── 混淆矩阵 ───
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12, 12))
# interp.plot_top_losses(9) # 看最差的预测
运行步骤
pip install fastai
python fastai_pets.py
首次运行自动下载数据集(~800 MB)和 ResNet34 预训练权重(~80 MB)。
预期输出
epoch train_loss valid_loss error_rate accuracy time
0 1.245 0.678 0.212 0.788 00:45
1 0.567 0.398 0.124 0.876 00:48
2 0.345 0.312 0.098 0.902 00:47
...
epoch train_loss valid_loss error_rate accuracy time
0 0.234 0.278 0.089 0.911 01:02
...
4 0.056 0.198 0.064 0.936 01:01
品种: Ragdoll (置信度: 97.34%)