TensorFlow Keras —— MNIST 手写数字分类
目标
- 掌握 `tf.keras.Sequential` 模型构建
- 理解 `compile → fit → evaluate → predict` 标准流程
- 使用 TensorBoard 可视化训练过程
完整代码
import tensorflow as tf
import numpy as np
from datetime import datetime
# Report the TensorFlow build and whether any GPU is visible to it.
print(f"TensorFlow {tf.__version__}")
print(f"GPU: {tf.config.list_physical_devices('GPU')}")
# ============================================================
# 1. Load and preprocess the data
# ============================================================
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale pixels from [0, 255] to [0, 1] as float32, then append a trailing
# channel axis so each image becomes (28, 28, 1), e.g. (60000, 28, 28, 1).
x_train = np.expand_dims(x_train.astype("float32") / 255.0, axis=-1)
x_test = np.expand_dims(x_test.astype("float32") / 255.0, axis=-1)

# One-hot encode the 10 digit classes; the model is trained with
# categorical_crossentropy, which expects one-hot targets.
y_train, y_test = (
    tf.keras.utils.to_categorical(labels, 10) for labels in (y_train, y_test)
)

print(f"训练集: {x_train.shape}, 测试集: {x_test.shape}")
# ============================================================
# 2. Build the model (Sequential API)
# ============================================================
# Three convolution stages extract features; a dense head with dropout
# produces a softmax over the 10 digit classes.
model = tf.keras.Sequential([
    # Declare the input with an explicit Input object rather than passing
    # `input_shape=` to the first layer (deprecated in Keras 3). The Input
    # object is not a layer, so model.summary() is unchanged.
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.5),  # regularizes the dense head during training only
    tf.keras.layers.Dense(10, activation="softmax"),
], name="MNIST_CNN")
model.summary()
# ============================================================
# 3. Compile the model
# ============================================================
# Targets are one-hot vectors, hence categorical (not sparse) cross-entropy;
# Adam at a 1e-3 learning rate, reporting accuracy alongside the loss.
model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    metrics=["accuracy"],
)
# ============================================================
# 4. Callbacks
# ============================================================
# Every callback that decides which epoch is "best" monitors the same
# quantity (val_loss). The checkpoint written to disk and the weights that
# EarlyStopping restores in memory therefore refer to the same epoch.
callbacks = [
    # Stop after 3 epochs without val_loss improvement and roll the model
    # back to its best weights.
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=3, restore_best_weights=True
    ),
    # Halve the learning rate after 2 stagnant epochs, but never below 1e-6.
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=2, min_lr=1e-6
    ),
    # Keep only the best checkpoint, judged by the same metric as above.
    tf.keras.callbacks.ModelCheckpoint(
        "best_model.keras", monitor="val_loss", save_best_only=True
    ),
    # One log directory per run, stamped with the local start time;
    # histogram_freq=1 writes weight histograms every epoch.
    tf.keras.callbacks.TensorBoard(
        log_dir=f"logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}",
        histogram_freq=1,
    ),
]
# ============================================================
# 5. Train
# ============================================================
# validation_split=0.1 holds out 10% of the training data; the callbacks
# above watch the resulting val_* metrics. The returned History object keeps
# per-epoch metric values.
history = model.fit(
    x=x_train,
    y=y_train,
    epochs=10,
    batch_size=128,
    validation_split=0.1,
    verbose=1,
    callbacks=callbacks,
)
# ============================================================
# 6. Evaluate
# ============================================================
# With metrics=["accuracy"], evaluate() yields [loss, accuracy].
test_loss, test_acc = model.evaluate(x=x_test, y=y_test, verbose=0)
accuracy_text = f"{test_acc:.2%}"
print(f"\n✅ 测试准确率: {accuracy_text}")
# ============================================================
# 7. Predict
# ============================================================
# Class probabilities for the first five test images, one row per image.
predictions = model.predict(x_test[:5])
pred_classes = predictions.argmax(axis=1)
true_classes = y_test[:5].argmax(axis=1)  # undo the one-hot encoding

for i, (predicted, actual) in enumerate(zip(pred_classes, true_classes)):
    confidence = predictions[i, predicted]  # probability of the predicted class
    print(f"样本 {i}: 预测={predicted}, 实际={actual}, "
          f"置信度={confidence:.3f}")
# ============================================================
# 8. Save and load
# ============================================================
model.save("mnist_cnn.keras")
print("模型已保存为 mnist_cnn.keras")

# Reload from disk and re-evaluate. Use the result: the point of this step is
# to confirm the saved file reproduces the in-memory model (test_acc, above).
loaded_model = tf.keras.models.load_model("mnist_cnn.keras")
loaded_loss, loaded_acc = loaded_model.evaluate(x_test, y_test, verbose=0)
print(f"加载后模型测试准确率: {loaded_acc:.2%}")
if not np.isclose(loaded_acc, test_acc):
    print("⚠️ 加载后的模型准确率与原模型不一致")

# ============================================================
# 9. Launch TensorBoard
# ============================================================
# In a terminal: tensorboard --logdir=logs
# Then open in a browser: http://localhost:6006
模型摘要输出
Model: "MNIST_CNN"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 26, 26, 32) 320
max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0
conv2d_1 (Conv2D) (None, 11, 11, 64) 18,496
max_pooling2d_1 (MaxPooling2D) (None, 5, 5, 64) 0
conv2d_2 (Conv2D) (None, 3, 3, 64) 36,928
flatten (Flatten) (None, 576) 0
dense (Dense) (None, 128) 73,856
dropout (Dropout) (None, 128) 0
dense_1 (Dense) (None, 10) 1,290
=================================================================
Total params: 130,890
训练管道速查
| 步骤 | API |
|------|-----|
| 构建 | `Sequential()` 或 Functional API |
| 编译 | `model.compile(optimizer, loss, metrics)` |
| 训练 | `model.fit(x, y, epochs, batch_size, callbacks)` |
| 评估 | `model.evaluate(x_test, y_test)` |
| 预测 | `model.predict(x_new)` |
| 保存 | `model.save("model.keras")` |
| 加载 | `tf.keras.models.load_model("model.keras")` |
TensorFlow Functional API + tf.data Pipeline
目标
- 掌握 Functional API(多输入/多输出/共享层)
- 掌握 tf.data 高性能输入管道
- 自定义训练循环(Custom Training Loop)
完整代码
1. Functional API —— 灵活模型构建
import tensorflow as tf
# ============================================================
# Functional API: ResNet-style residual block
# ============================================================
def residual_block(x, filters, kernel_size=3, stride=1):
    """Apply a basic two-convolution residual block: ReLU(F(x) + shortcut).

    Main path:  Conv(stride) -> BN -> ReLU -> Conv -> BN.
    Shortcut:   identity, or a 1x1 Conv(stride) -> BN projection when the
                stride or the channel count changes.

    Args:
        x: Input Keras tensor. The channel count is read from ``x.shape[-1]``,
            so this assumes channels-last data -- confirm if the data format
            is ever changed.
        filters: Number of output channels of the block.
        kernel_size: Kernel size of both main-path convolutions.
        stride: Stride of the first convolution and of the projection;
            a stride of 2 downsamples the spatial dimensions.

    Returns:
        Keras tensor with ``filters`` channels.
    """
    shortcut = x

    x = tf.keras.layers.Conv2D(filters, kernel_size, strides=stride, padding="same")(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    x = tf.keras.layers.Conv2D(filters, kernel_size, padding="same")(x)
    x = tf.keras.layers.BatchNormalization()(x)

    # When shapes differ, project the shortcut so Add() is well-defined.
    # The projection is batch-normalized just like the main path (as in the
    # reference ResNet), so the two branches are summed on comparable scales.
    if stride != 1 or shortcut.shape[-1] != filters:
        shortcut = tf.keras.layers.Conv2D(filters, 1, strides=stride, padding="same")(shortcut)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut)

    x = tf.keras.layers.Add()([x, shortcut])
    return tf.keras.layers.ReLU()(x)
# Build the model: stem -> three residual stages -> global pooling -> classifier.
inputs = tf.keras.Input(shape=(224, 224, 3), name="image")

# Stem: 7x7 stride-2 convolution + BN + ReLU, then 3x3 stride-2 max pooling.
x = tf.keras.layers.Conv2D(64, 7, strides=2, padding="same")(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.MaxPooling2D(3, strides=2, padding="same")(x)

# Residual stages as (filters, stride); the stride-2 stages downsample.
for stage_filters, stage_stride in ((64, 1), (128, 2), (256, 2)):
    x = residual_block(x, stage_filters, stride=stage_stride)

# Head: global average pooling feeding a 1000-way softmax.
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1000, activation="softmax")(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name="MiniResNet")
print(f"模型参数量: {model.count_params():,}")
2. 多输入/多输出模型
# Scenario: product recommendation -- image + text features -> category + price.
image_input = tf.keras.Input(shape=(224, 224, 3), name="image")
text_input = tf.keras.Input(shape=(100,), name="text_features")

# Image branch: ImageNet-pretrained MobileNetV2 backbone without its classifier.
# Giving input_shape builds the backbone for exactly the shape it will receive,
# instead of Keras warning and loading weights for an assumed default shape.
# NOTE(review): ImageNet MobileNetV2 weights expect pixels scaled to [-1, 1]
# (see tf.keras.applications.mobilenet_v2.preprocess_input). Nothing here
# rescales `image`, so callers presumably feed preprocessed pixels -- confirm.
x_img = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights="imagenet"
)(image_input)
x_img = tf.keras.layers.GlobalAveragePooling2D()(x_img)
x_img = tf.keras.layers.Dense(128, activation="relu")(x_img)

# Text branch: two dense layers over the 100-dim text feature vector.
x_text = tf.keras.layers.Dense(128, activation="relu")(text_input)
x_text = tf.keras.layers.Dense(128, activation="relu")(x_text)

# Fusion: concatenate both 128-dim embeddings into one trunk that both heads share.
combined = tf.keras.layers.Concatenate()([x_img, x_text])
shared = tf.keras.layers.Dense(256, activation="relu")(combined)
shared = tf.keras.layers.Dropout(0.3)(shared)

# Two heads. Their layer names ("category", "price") are the keys compile() uses below.
category_output = tf.keras.layers.Dense(50, activation="softmax", name="category")(shared)
price_output = tf.keras.layers.Dense(1, activation="linear", name="price")(shared)

multi_model = tf.keras.Model(
    inputs=[image_input, text_input],
    outputs=[category_output, price_output],
)

# Per-output loss and metric, keyed by output-layer name; the price
# regression loss contributes at half weight to the total loss.
multi_model.compile(
    optimizer="adam",
    loss={"category": "categorical_crossentropy", "price": "mse"},
    loss_weights={"category": 1.0, "price": 0.5},
    metrics={"category": "accuracy", "price": "mae"},
)
3. tf.data 高性能管道
# ============================================================
# Build a dataset from image files
# ============================================================
def parse_image(filename, label):
    """Read one image file and return ``(image, label)`` ready for the model.

    The image is decoded to 3 channels, resized to 224x224 and scaled to
    float32 values in [0, 1]. ``label`` is passed through unchanged.

    ``tf.io.decode_image`` detects the format from the file contents, so
    PNG/BMP/GIF files decode as well as JPEG (``decode_jpeg`` accepts only
    JPEG). ``expand_animations=False`` guarantees a rank-3 result (the first
    frame of an animated GIF), which ``tf.image.resize`` requires.
    """
    raw = tf.io.read_file(filename)
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # Bilinear resize returns float32 but keeps the original [0, 255] range.
    image = tf.image.resize(image, [224, 224])
    image = tf.cast(image, tf.float32) / 255.0
    return image, label
def augment(image, label):
    """Randomly flip and jitter brightness/contrast of one ``(image, label)`` pair.

    Assumes a float image in [0, 1], as produced by ``parse_image``.
    Brightness and contrast jitter can push pixels outside that range, so the
    result is clipped back to [0, 1]. ``label`` is returned unchanged.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    # Keep augmented training images in the same value range as the
    # un-augmented validation images.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label
# Build high-throughput input pipelines.
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 64

# Assumes train_files/train_labels and val_files/val_labels (file paths and
# their labels) are defined elsewhere -- this snippet does not create them.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    # Shuffle the cheap (path, label) pairs before decoding; reshuffled every
    # epoch by default. If the dataset is larger than the 10000-element
    # buffer, the shuffle is only approximate.
    .shuffle(buffer_size=10000)
    .map(parse_image, num_parallel_calls=AUTOTUNE)  # decode/resize in parallel
    .map(augment, num_parallel_calls=AUTOTUNE)      # augmentation: training only
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)  # prepare upcoming batches while the model trains
)

val_ds = (
    tf.data.Dataset.from_tensor_slices((val_files, val_labels))
    .map(parse_image, num_parallel_calls=AUTOTUNE)  # no shuffle/augment for validation
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# `model` at this point is the MiniResNet built with the Functional API, which
# was never compiled; fit() raises on an uncompiled model, so compile it first.
# NOTE(review): categorical_crossentropy assumes one-hot labels over the
# model's 1000 classes (matching the CategoricalCrossentropy used in the custom
# training loop); use sparse_categorical_crossentropy for integer labels.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
# A tf.data Dataset can be passed straight to fit().
model.fit(train_ds, validation_data=val_ds, epochs=10)
4. 自定义训练循环
# ============================================================
# Custom training loop -- for when model.fit is not flexible enough
# ============================================================
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss_fn = tf.keras.losses.CategoricalCrossentropy()  # expects one-hot labels
train_acc = tf.keras.metrics.CategoricalAccuracy()
val_acc = tf.keras.metrics.CategoricalAccuracy()
# Running means of the per-batch losses, so each epoch reports loss as well
# as accuracy (the batch loss was previously computed and then dropped).
train_loss = tf.keras.metrics.Mean()
val_loss = tf.keras.metrics.Mean()


@tf.function  # trace into a computation graph for speed
def train_step(images, labels):
    """Run one optimization step on a batch, update train metrics, return the batch loss."""
    with tf.GradientTape() as tape:
        # training=True puts layers such as BatchNormalization in training mode.
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss.update_state(loss)
    train_acc.update_state(labels, predictions)
    return loss


@tf.function
def val_step(images, labels):
    """Evaluate one batch in inference mode, update validation metrics, return the batch loss."""
    predictions = model(images, training=False)
    loss = loss_fn(labels, predictions)
    val_loss.update_state(loss)
    val_acc.update_state(labels, predictions)
    return loss


# Training loop
for epoch in range(10):
    # Metrics accumulate across calls, so clear them at the start of each epoch.
    for metric in (train_loss, train_acc, val_loss, val_acc):
        metric.reset_state()

    for batch_images, batch_labels in train_ds:
        train_step(batch_images, batch_labels)
    for batch_images, batch_labels in val_ds:
        val_step(batch_images, batch_labels)

    print(f"Epoch {epoch+1}: "
          f"Train Loss={float(train_loss.result()):.4f}, "
          f"Train Acc={float(train_acc.result()):.3f}, "
          f"Val Loss={float(val_loss.result()):.4f}, "
          f"Val Acc={float(val_acc.result()):.3f}")
关键要点
| 概念 | 说明 |
|------|------|
| Functional API | `Input → Layer → Model(inputs, outputs)` |
| 多输入/输出 | 以输出层名称为键,用字典方式指定 loss/metric |
| `tf.data.AUTOTUNE` | 自动调优并行度 |
| `.prefetch()` | 在 GPU 处理当前批次时由 CPU 预取下一批数据,减少 GPU 等待 |
| `@tf.function` | 将 Python 函数编译为计算图 |
| `tf.GradientTape` | 记录前向传播,用于自定义梯度计算 |