Keras 入门教程:从 Sequential 到迁移学习
1. Keras 是什么?
Keras 是 TensorFlow 的官方高级 API,核心理念是"让深度学习变得简单"。它抽象了张量运算、自动微分、GPU 调度等底层细节,让开发者专注于模型架构。
三大编程范式
| 范式 |
API |
适用场景 |
| Sequential |
keras.Sequential |
层堆叠的简单模型 |
| Functional |
keras.Model |
多输入/输出、共享层、残差连接 |
| Subclassing |
class MyModel(keras.Model) |
自定义训练逻辑、研究实验 |
2. Sequential 入门
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(784,)),
layers.Dense(64, activation='relu'),
layers.Dense(10, activation='softmax'),
])
Sequential 假设"输入→层1→层2→...→输出"的线性拓扑。input_shape 仅需在第一层指定。
3. Functional API:多输入多输出
inputs = keras.Input(shape=(784,))
x = layers.Dense(64, activation='relu')(inputs)
x = layers.Dense(64, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs=inputs, outputs=outputs)
Functional API 的核心:层是函数,接受张量返回张量。这让你可以构建 DAG(有向无环图)——跳过连接、分支、合并等复杂拓扑。
4. Subclassing:完全控制
class MyModel(keras.Model):
def __init__(self):
super().__init__()
self.dense1 = layers.Dense(64, activation='relu')
self.dense2 = layers.Dense(10, activation='softmax')
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
Subclassing 适合研究场景,但不会自动暴露层结构——model.summary() 需先 build。
5. 回调(Callbacks)与训练监控
callbacks = [
keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
keras.callbacks.ModelCheckpoint('best.keras', save_best_only=True),
keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=2),
]
model.fit(x_train, y_train, epochs=50, callbacks=callbacks, validation_split=0.1)
推荐底线配置:EarlyStopping + ModelCheckpoint 至少这两个。
6. 迁移学习实战(Keras 3.0)
# 加载预训练模型
base_model = keras.applications.EfficientNetV2B0(
include_top=False,
weights='imagenet',
input_shape=(224, 224, 3)
)
base_model.trainable = False # 冻结
# 添加自定义分类头
inputs = keras.Input(shape=(224, 224, 3))
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(10, activation='softmax')(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
先用冻结 backbone 训练分类头(5-10 epoch),再解冻微调(学习率降低 10x)。
7. 常见坑与调试
| 问题 |
原因 |
解决 |
| shape mismatch |
Dense 输出与标签维度不一致 |
确认最后一层 units == num_classes |
| loss 不下降 |
学习率太大/太小 |
从 1e-3 开始,使用 ReduceLROnPlateau |
| OOM |
batch_size 太大 |
减小到 32 或 16 |
| 训练过拟合 |
数据少/无正则化 |
加 Dropout + 数据增强 |
思考题
- Functional API 相比 Sequential 的不可替代场景有哪些?
BatchNormalization 在 activation 之前还是之后?为什么?
- 迁移学习中为什么先冻结 backbone → 再微调,而不是直接全部训练?