文档
FastAI 入门教程:DataBlock 与 lr_find
1. FastAI 的哲学
FastAI 的三条原则:
- 类型派发(type dispatch):框架检测你的数据类型(图像/文本/表格),自动应用最佳增强和预处理
- 最佳实践内置:
fine_tune()内含差分学习率、渐进式解冻、1-cycle 策略 - 声明式 > 命令式:描述"数据长什么样"而非"如何加载数据"
2. DataBlock:声明式数据管道
from fastai.vision.all import *
pets = DataBlock(
blocks=(ImageBlock, CategoryBlock), # 输入是图像,输出是类别
get_items=get_image_files, # 如何获取所有样本
splitter=RandomSplitter(seed=42), # 如何划分训练/验证
get_y=using_attr(RegexLabeller(r'^(.*)_\d+\.jpg$'), 'name'), # 如何获取标签
item_tfms=Resize(460), # 每个样本的变换
batch_tfms=aug_transforms(size=224), # 批量变换(含数据增强)
)
dls = pets.dataloaders(path/"images", bs=64)
DataBlock 的五要素:
| 要素 | API | 说明 |
|---|---|---|
| blocks | (ImageBlock, CategoryBlock) |
输入/输出类型 |
| get_items | get_image_files |
如何找数据 |
| splitter | RandomSplitter(0.2) |
如何分集 |
| get_y | RegexLabeller(...) |
如何提取标签 |
| transforms | item_tfms / batch_tfms |
数据变换 |
3. lr_find:自动找最佳学习率
learn = vision_learner(dls, resnet34, metrics=accuracy)
lr_min, lr_steep = learn.lr_find(suggest_funcs=(minimum, steep))
print(f"建议 lr: {lr_steep:.2e}")
规则:选择 loss 曲线最陡下降点,或 loss 最小点 ÷ 10。FastAI 的 fine_tune() 自动调用 lr_find。
4. fine_tune 内部原理
# learn.fine_tune(5) 等价于:
# Phase 1: 冻结 backbone,只训练 head(1 epoch)
learn.fit_one_cycle(1, lr_max=lr) # 使用 lr_find 找到的 lr
# Phase 2: 解冻全部,差分学习率微调(4 epoch)
learn.unfreeze()
learn.fit_one_cycle(4, lr_max=slice(lr/100, lr/10)) # head 用 lr/10,backbone 用 lr/100
slice(a, b) 表示:最后层 lr=b,第一层 lr=a,中间指数衰减。这是 FastAI 最精妙的设计之一。
5. 常见任务速查
| 任务 | 模块 | 典型代码 |
|---|---|---|
| 图像分类 | fastai.vision |
vision_learner(dls, resnet34) |
| 图像分割 | fastai.vision |
unet_learner(dls, resnet34) |
| 文本分类 | fastai.text |
text_classifier_learner(dls, AWD_LSTM) |
| 表格数据 | fastai.tabular |
tabular_learner(dls, layers=[200,100]) |
| 协同过滤 | fastai.collab |
collab_learner(dls, n_factors=50) |
6. 测试时增强 (TTA)
preds, targs = learn.tta() # 对测试样本做多个增强版本取平均
accuracy(preds, targs)
TTA 通常在验证集上提升 0.5%-2% 准确率,零成本。
7. FastAI vs Keras vs PyTorch 原生
| 维度 | FastAI | Keras | PyTorch |
|---|---|---|---|
| 上手速度 | ⭐⭐⭐ | ⭐⭐ | ⭐ |
| 内置 SOTA 技巧 | ⭐⭐⭐ | ⭐ | ⭐ |
| 自定义灵活性 | ⭐⭐ | ⭐⭐ | ⭐⭐⭐ |
| 生产部署 | ⭐⭐ | ⭐⭐⭐ | ⭐⭐⭐ |
| 教材配套 | ⭐⭐⭐(免费课) | ⭐⭐ | ⭐ |
思考题
slice(lr/100, lr/10)的差分学习率为什么 head 大、backbone 小?- DataBlock API 相比手写 PyTorch Dataset 的优势和局限是什么?
- 什么时候用
fine_tune,什么时候用fit_one_cycle?