4 行代码训练宠物品种分类器

知识库
知识库文档
/tech-stacks/fastai/examples/4 行代码训练宠物品种分类器.md

文档

FastAI 宠物品种分类:4 行代码 SOTA

目标

用 FastAI 的高级 API,仅 4 行核心代码训练一个 37 种宠物品种分类模型(Oxford-IIIT Pet Dataset)。

完整代码

from fastai.vision.all import *

# ─── 第 1 行:下载数据 ───
path = untar_data(URLs.PETS)

# ─── 第 2 行:声明式数据管道 ───
dls = ImageDataLoaders.from_name_re(
    path,                           # 数据根目录
    get_image_files(path/"images"), # 所有图片
    pat=r'^(.*)_\d+\.jpg$',        # 正则提取:品种_编号.jpg → 品种
    item_tfms=Resize(460),          # 先统一大小
    batch_tfms=aug_transforms(size=224, min_scale=0.75),  # 数据增强
    bs=64,                          # batch size
)

# 查看数据
dls.show_batch(max_n=6, nrows=2)
# plt.show()

# ─── 第 3 行:创建学习器 ───
learn = vision_learner(
    dls,
    resnet34,                   # backbone:ResNet-34 预训练权重
    metrics=[accuracy, error_rate],
    # 自动:差分学习率、混合精度、LabelSmoothing
)

# ─── 第 4 行:训练! ───
learn.fine_tune(5)              # 5 epoch:1 epoch 训练 head + 4 epoch 微调全部

# ─── 预测 ───
img = PILImage.create("my_cat.jpg")
breed, _, probs = learn.predict(img)
print(f"品种: {breed} (置信度: {probs.max():.2%})")

# ─── 混淆矩阵 ───
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12, 12))
# interp.plot_top_losses(9)  # 看最差的预测

运行步骤

pip install fastai
python fastai_pets.py

首次运行自动下载数据集(~800 MB)和 ResNet34 预训练权重(~80 MB)。

预期输出

epoch   train_loss  valid_loss  error_rate  accuracy  time
0       1.245       0.678       0.212       0.788     00:45
1       0.567       0.398       0.124       0.876     00:48
2       0.345       0.312       0.098       0.902     00:47
...
epoch   train_loss  valid_loss  error_rate  accuracy  time
0       0.234       0.278       0.089       0.911     01:02
...
4       0.056       0.198       0.064       0.936     01:01

品种: Ragdoll (置信度: 97.34%)

信息

路径
/tech-stacks/fastai/examples/4 行代码训练宠物品种分类器.md
更新时间
2026/5/31