02-functional-api-tf-data

知识库
知识库文档
/tech-stacks/tensorflow/examples/02-functional-api-tf-data.md

文档

TensorFlow Functional API + tf.data Pipeline

目标

  • 掌握 Functional API(多输入/多输出/共享层)
  • 掌握 tf.data 高性能输入管道
  • 自定义训练循环(Custom Training Loop)

完整代码

1. Functional API —— 灵活模型构建

import tensorflow as tf

# ============================================================
# ResNet-style residual block built with the Functional API
# ============================================================
def residual_block(x, filters, kernel_size=3, stride=1):
    """Apply a two-convolution residual block and return ``F(x) + shortcut``.

    Args:
        x: Input feature-map tensor. The channel count is read from
            ``shape[-1]``, so a channels-last layout is assumed.
        filters: Number of output channels of both main-path convolutions.
        kernel_size: Kernel size of the two main-path convolutions.
        stride: Stride of the first convolution and of the projection
            shortcut, when one is needed.

    Returns:
        The ReLU-activated sum of the main path and the shortcut.
    """
    layers = tf.keras.layers
    skip = x

    # Main path: conv -> BN -> ReLU -> conv -> BN
    out = layers.Conv2D(filters, kernel_size, strides=stride, padding="same")(x)
    out = layers.BatchNormalization()(out)
    out = layers.ReLU()(out)
    out = layers.Conv2D(filters, kernel_size, padding="same")(out)
    out = layers.BatchNormalization()(out)

    # The identity shortcut only works when spatial size and channel count are
    # unchanged; otherwise project it with a 1x1 convolution.
    shapes_match = stride == 1 and skip.shape[-1] == filters
    if not shapes_match:
        skip = layers.Conv2D(filters, 1, strides=stride, padding="same")(skip)

    merged = layers.Add()([out, skip])
    return layers.ReLU()(merged)


# Assemble the model: stem -> residual stages -> classification head
inputs = tf.keras.Input(shape=(224, 224, 3), name="image")

# Stem: 7x7 stride-2 convolution, batch norm, ReLU, then 3x3 stride-2 max pooling
x = tf.keras.layers.Conv2D(64, 7, strides=2, padding="same")(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.MaxPooling2D(3, strides=2, padding="same")(x)

# Stacked residual stages, listed as (filters, stride) pairs
for stage_filters, stage_stride in ((64, 1), (128, 2), (256, 2)):
    x = residual_block(x, stage_filters, stride=stage_stride)

# Head: global average pooling followed by a 1000-way softmax classifier
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1000, activation="softmax")(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name="MiniResNet")
param_count = model.count_params()
print(f"模型参数量: {param_count:,}")

2. 多输入/多输出模型

# Scenario: product recommendation -- image + text features -> category + price
image_input = tf.keras.Input(shape=(224, 224, 3), name="image")
text_input = tf.keras.Input(shape=(100,), name="text_features")

# Image branch: ImageNet-pretrained MobileNetV2 backbone without its classifier
# head. input_shape is passed explicitly so the backbone is built for the
# 224x224 inputs instead of falling back to the default with a warning.
# NOTE(review): the ImageNet weights expect inputs preprocessed with
# tf.keras.applications.mobilenet_v2.preprocess_input -- confirm the input
# pipeline feeding "image" does this.
x_img = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3), include_top=False, weights="imagenet"
)(image_input)
x_img = tf.keras.layers.GlobalAveragePooling2D()(x_img)
x_img = tf.keras.layers.Dense(128, activation="relu")(x_img)

# Text branch: two dense layers over the 100-dimensional text feature vector
x_text = tf.keras.layers.Dense(128, activation="relu")(text_input)
x_text = tf.keras.layers.Dense(128, activation="relu")(x_text)

# Fusion: concatenate both branches and pass them through a shared trunk.
# (The original used a stray walrus expression, `shared := combined`, as the
# layer input; it only created a confusing alias and is removed here.)
combined = tf.keras.layers.Concatenate()([x_img, x_text])
shared = tf.keras.layers.Dense(256, activation="relu")(combined)
shared = tf.keras.layers.Dropout(0.3)(shared)

# Two output heads on top of the shared trunk
category_output = tf.keras.layers.Dense(50, activation="softmax", name="category")(shared)
price_output = tf.keras.layers.Dense(1, activation="linear", name="price")(shared)

multi_model = tf.keras.Model(
    inputs=[image_input, text_input],
    outputs=[category_output, price_output],
)

# Per-output loss, loss weight and metric, keyed by the output layer names.
# "categorical_crossentropy" expects one-hot category labels.
multi_model.compile(
    optimizer="adam",
    loss={"category": "categorical_crossentropy", "price": "mse"},
    loss_weights={"category": 1.0, "price": 0.5},
    metrics={"category": "accuracy", "price": "mae"},
)

3. tf.data 高性能管道

# ============================================================
# Build a dataset from image files
# ============================================================
def parse_image(filename, label):
    """Load one JPEG file and turn it into a model-ready image tensor.

    Args:
        filename: Path of the JPEG file to read.
        label: Label belonging to the image; passed through unchanged.

    Returns:
        A ``(image, label)`` tuple where ``image`` is a 3-channel float32
        tensor resized to 224x224 and divided by 255.
    """
    raw_bytes = tf.io.read_file(filename)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(decoded, [224, 224])
    scaled = tf.cast(resized, tf.float32) / 255.0
    return scaled, label


def augment(image, label):
    """Apply random training-time augmentation to one image.

    Randomly flips the image horizontally, then jitters brightness and
    contrast. The result is clipped back to [0, 1] because the brightness and
    contrast ops can push pixel values outside that range.

    Args:
        image: Float image tensor.
            NOTE(review): assumed to be scaled to [0, 1], as produced by
            ``parse_image`` -- confirm for any other caller.
        label: Label belonging to the image; passed through unchanged.

    Returns:
        A ``(augmented_image, label)`` tuple.
    """
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    # Keep training images in the same value range as the validation images.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label


# Build the high-performance input pipelines
AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 64

# train_files/train_labels and val_files/val_labels (file paths and their
# labels) are assumed to be defined elsewhere.
train_ds = (
    tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    # Shuffle before decoding so only (path, label) pairs sit in the buffer.
    # A fully uniform shuffle needs buffer_size >= the dataset size.
    .shuffle(buffer_size=10000)
    .map(parse_image, num_parallel_calls=AUTOTUNE)  # decode in parallel
    .map(augment, num_parallel_calls=AUTOTUNE)      # training-only augmentation
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)  # overlap input preparation with model execution
)

# Validation data is neither shuffled nor augmented.
val_ds = (
    tf.data.Dataset.from_tensor_slices((val_files, val_labels))
    .map(parse_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# model.fit requires a compiled model, and the model built above is never
# compiled elsewhere, so configure optimizer, loss and metric here.
# NOTE(review): categorical cross-entropy assumes one-hot labels; use
# "sparse_categorical_crossentropy" if the labels are integer class ids.
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# A tf.data Dataset can be passed to fit directly
model.fit(train_ds, validation_data=val_ds, epochs=10)

4. 自定义训练循环

# ============================================================
# Custom training loop, for when model.fit is not flexible enough
# ============================================================
# Optimizer and loss used by train_step
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.CategoricalCrossentropy()

# Separate accuracy trackers so training and validation results do not mix
train_acc = tf.keras.metrics.CategoricalAccuracy()
val_acc = tf.keras.metrics.CategoricalAccuracy()

@tf.function  # trace into a computation graph for speed
def train_step(images, labels):
    """Run one optimization step on a batch and return its loss.

    The forward pass runs under a gradient tape with ``training=True``; the
    resulting gradients are applied to the model's trainable variables and
    the training accuracy metric is updated with this batch.

    Args:
        images: Batch of input images fed to ``model``.
        labels: Targets passed to ``loss_fn`` and to the accuracy metric.

    Returns:
        The loss value computed by ``loss_fn`` for this batch.
    """
    weights = model.trainable_variables

    with tf.GradientTape() as tape:
        probs = model(images, training=True)
        batch_loss = loss_fn(labels, probs)

    grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    train_acc.update_state(labels, probs)
    return batch_loss


@tf.function
def val_step(images, labels):
    """Update the validation accuracy with one batch; no weights are changed.

    Args:
        images: Batch of input images fed to ``model`` with ``training=False``.
        labels: Targets compared against the model's predictions.
    """
    val_acc.update_state(labels, model(images, training=False))


# Training loop
for epoch in range(10):
    # Start every epoch with fresh accuracy accumulators
    for metric in (train_acc, val_acc):
        metric.reset_state()

    # One pass over the training data, updating the weights batch by batch
    for images, labels in train_ds:
        loss = train_step(images, labels)

    # One pass over the validation data, metrics only
    for images, labels in val_ds:
        val_step(images, labels)

    train_accuracy = train_acc.result()
    val_accuracy = val_acc.result()
    print(f"Epoch {epoch+1}: Train Acc={train_accuracy:.3f}, "
          f"Val Acc={val_accuracy:.3f}")

关键要点

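| 概念 | 说明 |
| --- | --- |
| Functional API | Input → Layer → Model(inputs, outputs) |
| 多输入/输出 | 按输出层名称以字典方式分别指定 loss/metric |
| tf.data.AUTOTUNE | 由 tf.data 在运行时自动调优并行度与预取缓冲区大小 |
| .prefetch() | 在模型处理当前批次的同时预取下一批数据,减少 GPU 等待 |
| @tf.function | 将 Python 函数编译为计算图 |
| tf.GradientTape | 记录前向传播中的运算,用于自动求导(自定义训练循环) |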

信息

路径
/tech-stacks/tensorflow/examples/02-functional-api-tf-data.md
更新时间
2026/5/30