TensorFlow Functional API + tf.data Pipeline
目标
- 掌握 Functional API(多输入/多输出/共享层)
- 掌握 tf.data 高性能输入管道
- 自定义训练循环(Custom Training Loop)
完整代码
1. Functional API —— 灵活模型构建
import tensorflow as tf
# ============================================================
# Functional API: a ResNet-style residual block
# ============================================================
def residual_block(x, filters, kernel_size=3, stride=1):
    """Apply a two-convolution residual unit: ``ReLU(F(x) + shortcut)``.

    Args:
        x: Input feature map (a Keras tensor whose last axis is channels).
        filters: Number of output channels of both main-path convolutions.
        kernel_size: Kernel size of the two main-path convolutions.
        stride: Stride of the first convolution, and of the projection
            shortcut when one is required.

    Returns:
        The block's output tensor, taken after the final ReLU.
    """
    layers = tf.keras.layers
    identity = x

    # Main path F(x): Conv -> BN -> ReLU -> Conv -> BN.
    out = layers.Conv2D(filters, kernel_size, strides=stride, padding="same")(x)
    out = layers.BatchNormalization()(out)
    out = layers.ReLU()(out)
    out = layers.Conv2D(filters, kernel_size, padding="same")(out)
    out = layers.BatchNormalization()(out)

    # Add() needs both operands to have the same shape, so the shortcut is
    # projected with a 1x1 convolution whenever the stride or the channel
    # count differs from the main path.
    needs_projection = stride != 1 or identity.shape[-1] != filters
    if needs_projection:
        identity = layers.Conv2D(filters, 1, strides=stride, padding="same")(identity)

    merged = layers.Add()([out, identity])
    return layers.ReLU()(merged)
# Build the model: stem -> residual stages -> classification head.
inputs = tf.keras.Input(shape=(224, 224, 3), name="image")

# Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
x = tf.keras.layers.Conv2D(filters=64, kernel_size=7, strides=2, padding="same")(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.MaxPooling2D(pool_size=3, strides=2, padding="same")(x)

# Stack the residual blocks; each (filters, stride) pair is one stage.
for stage_filters, stage_stride in ((64, 1), (128, 2), (256, 2)):
    x = residual_block(x, stage_filters, stride=stage_stride)

# Head: global average pooling followed by a 1000-way softmax classifier.
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(units=1000, activation="softmax")(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name="MiniResNet")
print(f"模型参数量: {model.count_params():,}")
2. 多输入/多输出模型
# Scenario: product recommendation -- image + text features -> category + price.
image_input = tf.keras.Input(shape=(224, 224, 3), name="image")
text_input = tf.keras.Input(shape=(100,), name="text_features")

# Image branch: ImageNet-pretrained MobileNetV2 without its classifier head.
# input_shape is passed explicitly so Keras does not have to fall back to
# (and warn about) a default resolution when loading the ImageNet weights.
# NOTE(review): MobileNetV2 expects pixels scaled to [-1, 1] (see
# tf.keras.applications.mobilenet_v2.preprocess_input) -- confirm that the
# data fed to "image" is preprocessed accordingly.
backbone = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights="imagenet"
)
x_img = backbone(image_input)
x_img = tf.keras.layers.GlobalAveragePooling2D()(x_img)
x_img = tf.keras.layers.Dense(128, activation="relu")(x_img)

# Text branch: two dense layers over the 100-dimensional feature vector.
x_text = tf.keras.layers.Dense(128, activation="relu")(text_input)
x_text = tf.keras.layers.Dense(128, activation="relu")(x_text)

# Fusion: concatenate both branches and learn a shared representation.
combined = tf.keras.layers.Concatenate()([x_img, x_text])
shared = tf.keras.layers.Dense(256, activation="relu")(combined)
shared = tf.keras.layers.Dropout(0.3)(shared)

# Two output heads on top of the shared representation. The layer names
# are the keys used in the compile() dictionaries below.
category_output = tf.keras.layers.Dense(50, activation="softmax", name="category")(shared)
price_output = tf.keras.layers.Dense(1, activation="linear", name="price")(shared)

multi_model = tf.keras.Model(
    inputs=[image_input, text_input],
    outputs=[category_output, price_output],
)

# Each output gets its own loss and metric; loss_weights sets how much each
# loss contributes to the total that is optimized.
multi_model.compile(
    optimizer="adam",
    loss={"category": "categorical_crossentropy", "price": "mse"},
    loss_weights={"category": 1.0, "price": 0.5},
    metrics={"category": "accuracy", "price": "mae"},
)
3. tf.data 高性能管道
# ============================================================
# Building a dataset from files
# ============================================================
def parse_image(filename, label):
    """Load one JPEG file as a 224x224 RGB float tensor scaled by 1/255.

    Args:
        filename: Path of the JPEG file to read.
        label: Label associated with the image; passed through unchanged.

    Returns:
        A ``(image, label)`` tuple where ``image`` is a float32 tensor.
    """
    raw_bytes = tf.io.read_file(filename)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, [224, 224])
    # Divide by 255 to bring 8-bit pixel intensities down to the [0, 1] range.
    normalized = tf.cast(resized, tf.float32) / 255.0
    return normalized, label
def augment(image, label):
    """Apply random flip, brightness and contrast jitter to one image.

    Args:
        image: Float image tensor. Assumed to hold values in [0, 1], which is
            what ``parse_image`` in this file produces -- confirm for any
            other caller, since the result is clipped to that range.
        label: Label associated with the image; passed through unchanged.

    Returns:
        A ``(image, label)`` tuple with the augmented image, clipped to [0, 1].
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    # Brightness adds an offset and contrast rescales around the mean, so
    # pixels can leave [0, 1]; clip to keep the model's input range stable.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label
# Build a high-throughput input pipeline.
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 64

# train_files/train_labels and val_files/val_labels are placeholders: lists
# of image paths and their labels that are assumed to be defined elsewhere.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    .shuffle(buffer_size=10000)                      # shuffle cheap (path, label) pairs before decoding
    .map(parse_image, num_parallel_calls=AUTOTUNE)   # decode files in parallel
    .map(augment, num_parallel_calls=AUTOTUNE)       # training-only augmentation
    .batch(BATCH_SIZE)                               # group samples into batches
    .prefetch(AUTOTUNE)                              # overlap input preparation with training
)

# Validation data is neither shuffled nor augmented.
val_ds = (
    tf.data.Dataset.from_tensor_slices((val_files, val_labels))
    .map(parse_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# A Keras model must be compiled before fit() can be called; without this
# step fit() raises an error because no loss/optimizer is configured.
# NOTE(review): categorical_crossentropy assumes one-hot labels over the
# model's 1000 classes, matching the CategoricalCrossentropy loss used by the
# custom training loop in this file -- switch to the sparse variant if the
# labels are integer class ids.
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# A tf.data Dataset can be passed to fit() directly.
model.fit(train_ds, validation_data=val_ds, epochs=10)
4. 自定义训练循环
# ============================================================
# For cases where model.fit is not flexible enough
# ============================================================
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.CategoricalCrossentropy()

# Separate accuracy trackers so that training and validation results do not
# accumulate into the same metric state.
train_acc, val_acc = (
    tf.keras.metrics.CategoricalAccuracy(),
    tf.keras.metrics.CategoricalAccuracy(),
)
@tf.function  # trace into a TensorFlow graph for faster execution
def train_step(images, labels):
    """Run one optimization step on a batch and return the batch loss.

    Uses the module-level ``model``, ``loss_fn``, ``optimizer`` and
    ``train_acc``: the model weights are updated in place and the batch is
    added to the running training accuracy.

    Args:
        images: Batch of input images fed to ``model``.
        labels: Matching batch of targets passed to ``loss_fn``.

    Returns:
        The loss value computed by ``loss_fn`` for this batch.
    """
    weights = model.trainable_variables
    with tf.GradientTape() as tape:
        # The forward pass must happen under the tape so gradients can be
        # derived from it; training=True runs the layers in training mode.
        probs = model(images, training=True)
        batch_loss = loss_fn(labels, probs)
    grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))
    train_acc.update_state(labels, probs)
    return batch_loss
@tf.function
def val_step(images, labels):
    """Evaluate one validation batch and fold it into ``val_acc``.

    Args:
        images: Batch of input images fed to ``model``.
        labels: Matching batch of targets compared against the predictions.
    """
    # training=False runs the layers in inference mode; no weights change.
    probs = model(images, training=False)
    val_acc.update_state(labels, probs)
# Training loop: one pass over train_ds and val_ds per epoch.
for epoch in range(10):
    # Metrics accumulate across update_state() calls, so clear both trackers
    # at the start of every epoch to report per-epoch accuracy.
    for metric in (train_acc, val_acc):
        metric.reset_state()

    for batch_images, batch_labels in train_ds:
        loss = train_step(batch_images, batch_labels)

    for batch_images, batch_labels in val_ds:
        val_step(batch_images, batch_labels)

    summary = (
        f"Epoch {epoch+1}: Train Acc={train_acc.result():.3f}, "
        f"Val Acc={val_acc.result():.3f}"
    )
    print(summary)
关键要点
| 概念 | 说明 |
| --- | --- |
| Functional API | Input → Layer → Model(inputs, outputs) |
| 多输入/输出 | 字典方式指定 loss/metric |
| `tf.data.AUTOTUNE` | 自动并行度调优 |
| `.prefetch()` | CPU 预取下一批数据,消除 GPU 等待 |
| `@tf.function` | 将 Python 函数编译为计算图 |
| `tf.GradientTape` | 记录前向传播,用于自定义梯度计算 |