TensorFlow Keras —— MNIST 手写数字分类
目标
- 掌握 `tf.keras.Sequential` 模型构建
- 理解 `compile → fit → evaluate → predict` 标准流程
- 使用 TensorBoard 可视化训练过程
完整代码
import tensorflow as tf
import numpy as np
from datetime import datetime
# Report the installed TensorFlow version and the GPUs it can see.
print(f"TensorFlow {tf.__version__}")
visible_gpus = tf.config.list_physical_devices('GPU')
print(f"GPU: {visible_gpus}")
# ============================================================
# 1. Load and preprocess the data
# ============================================================
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Convert pixels to float32 scaled by 1/255, then append a trailing channel
# axis so each image has shape (28, 28, 1) for Conv2D.
x_train = np.expand_dims(x_train.astype("float32") / 255.0, axis=-1)  # (60000, 28, 28, 1)
x_test = np.expand_dims(x_test.astype("float32") / 255.0, axis=-1)

# One-hot encode the digit labels into 10-element vectors
# (required by the categorical_crossentropy loss used at compile time).
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

print(f"训练集: {x_train.shape}, 测试集: {x_test.shape}")
# ============================================================
# 2. Build the model (Sequential API)
# ============================================================
model = tf.keras.Sequential(
    [
        # Declare the input explicitly rather than passing `input_shape=` to
        # the first layer; Keras 3 warns about `input_shape` in Sequential.
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D((2, 2)),
        tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dropout(0.5),  # regularisation before the classifier head
        # One softmax probability per digit class (10 classes).
        tf.keras.layers.Dense(10, activation="softmax"),
    ],
    name="MNIST_CNN",
)
model.summary()
# ============================================================
# 3. Compile the model
# ============================================================
# Adam optimiser; categorical cross-entropy matches the one-hot labels.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=optimizer,
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
# ============================================================
# 4. Callbacks
# ============================================================
# Every "best epoch" decision uses the same metric, so the weights that
# EarlyStopping restores and the file ModelCheckpoint keeps describe the
# same model.
MONITOR = "val_loss"

# One TensorBoard run directory per launch, named by local timestamp.
log_dir = f"logs/{datetime.now().strftime('%Y%m%d-%H%M%S')}"

callbacks = [
    # Stop after 3 epochs without improvement. restore_best_weights=True
    # rolls back to the best epoch's weights. Older tf.keras versions do this
    # only when training actually stops early.
    tf.keras.callbacks.EarlyStopping(
        monitor=MONITOR, patience=3, restore_best_weights=True
    ),
    # Halve the learning rate after 2 stagnant epochs, never below 1e-6.
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor=MONITOR, factor=0.5, patience=2, min_lr=1e-6
    ),
    # Keep only the best-so-far model on disk.
    tf.keras.callbacks.ModelCheckpoint(
        "best_model.keras", monitor=MONITOR, save_best_only=True
    ),
    # TensorBoard logs, with weight histograms computed every epoch.
    tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1),
]
# ============================================================
# 5. Train
# ============================================================
# Hold out 10% of the training data for validation; the callbacks watch the
# resulting val_* metrics.
fit_options = {
    "batch_size": 128,
    "epochs": 10,
    "validation_split": 0.1,
    "callbacks": callbacks,
    "verbose": 1,
}
history = model.fit(x_train, y_train, **fit_options)
# ============================================================
# 6. Evaluate
# ============================================================
# evaluate() returns the loss followed by the compiled metrics (here: accuracy).
scores = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_acc = scores
print(f"\n✅ 测试准确率: {test_acc:.2%}")
# ============================================================
# 7. Predict
# ============================================================
num_preview = 5  # number of test images to show predictions for

# One row of 10 softmax probabilities per previewed image.
predictions = model.predict(x_test[:num_preview])
pred_classes = np.argmax(predictions, axis=1)
true_classes = np.argmax(y_test[:num_preview], axis=1)  # undo the one-hot encoding

# The loop body must be indented; the flattened original raised IndentationError.
for i, (pred, actual) in enumerate(zip(pred_classes, true_classes)):
    print(f"样本 {i}: 预测={pred}, 实际={actual}, "
          f"置信度={predictions[i][pred]:.3f}")
# ============================================================
# 8. Save and load
# ============================================================
model.save("mnist_cnn.keras")
print("模型已保存为 mnist_cnn.keras")

# Reload from disk and re-evaluate. The original discarded this result, so the
# round trip was never actually checked. Report the reloaded accuracy next to
# the accuracy measured before saving.
loaded_model = tf.keras.models.load_model("mnist_cnn.keras")
_, loaded_acc = loaded_model.evaluate(x_test, y_test, verbose=0)
print(f"加载后测试准确率: {loaded_acc:.2%}(保存前: {test_acc:.2%})")
# ============================================================
# 9. Launch TensorBoard
# ============================================================
# In a terminal run: tensorboard --logdir=logs
# Then open in a browser: http://localhost:6006
模型摘要输出
Model: "MNIST_CNN"
_________________________________________________________________
Layer (type)                    Output Shape             Param #
=================================================================
conv2d (Conv2D)                 (None, 26, 26, 32)           320
max_pooling2d (MaxPooling2D)    (None, 13, 13, 32)             0
conv2d_1 (Conv2D)               (None, 11, 11, 64)        18,496
max_pooling2d_1 (MaxPooling2D)  (None, 5, 5, 64)               0
conv2d_2 (Conv2D)               (None, 3, 3, 64)          36,928
flatten (Flatten)               (None, 576)                    0
dense (Dense)                   (None, 128)               73,856
dropout (Dropout)               (None, 128)                    0
dense_1 (Dense)                 (None, 10)                 1,290
=================================================================
Total params: 130,890
训练管道速查
| 步骤 | API |
|------|-----|
| 构建 | `Sequential()` 或 Functional API |
| 编译 | `model.compile(optimizer, loss, metrics)` |
| 训练 | `model.fit(x, y, epochs, batch_size, callbacks)` |
| 评估 | `model.evaluate(x_test, y_test)` |
| 预测 | `model.predict(x_new)` |
| 保存 | `model.save("model.keras")` |
| 加载 | `tf.keras.models.load_model("model.keras")` |