01-Keras与部署实战

知识库
知识库文档
/tech-stacks/tensorflow/tutorial/01-Keras与部署实战.md

文档

TensorFlow 教程 —— Keras 进阶与模型部署

本章目标

  • 掌握 Keras Functional API 和子类化 API
  • 了解 TF 与 PyTorch 的核心差异
  • 掌握 TF Lite 与 TF Serving 部署

1. 三种建模 API 对比

Sequential(最简单)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
# 适用:单输入 → 单输出,线性堆叠

Functional API(中等灵活)

inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(128, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)
# 适用:多输入/输出、共享层、残差连接、分支合并

Subclassing(最灵活)

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(128, activation="relu")
        self.dense2 = tf.keras.layers.Dense(10, activation="softmax")

    def call(self, inputs, training=False):
        x = self.dense1(inputs)
        if training:
            x = tf.keras.layers.Dropout(0.5)(x)
        return self.dense2(x)

# 适用:需要动态控制流、自定义训练逻辑
# 缺点:无法自动 summary、无法直接 save/load 结构

2. TF vs PyTorch 核心差异

维度 TensorFlow PyTorch
高级 API Keras(官方) 无官方;Lightning 社区
计算图 静态(TF 1.x)→ Eager(TF 2.x) 始终动态
部署 TF Serving / TF Lite / TF.js TorchServe / ONNX
移动端 TF Lite(最成熟) PyTorch Mobile
浏览器 TF.js 有限
生产 Pipeline TFX(完整) 需拼凑
学术采用 较少 占主导
行业采用 较多(Google 生态) 快速增长

3. 模型部署

3.1 TensorFlow Lite(移动端/边缘端)

# 转换模型
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # 量化优化
tflite_model = converter.convert()

# 保存
with open("model.tflite", "wb") as f:
    f.write(tflite_model)

# 在 Android/iOS 上加载推理(需 TF Lite 运行时)
# 详细见: https://www.tensorflow.org/lite

3.2 TensorFlow Serving(服务器部署)

# 保存模型为 SavedModel 格式
model.save("mnist_model/1/", save_format="tf")

# Docker 启动 TF Serving
docker run -p 8501:8501 \
  --mount type=bind,source=$(pwd)/mnist_model,target=/models/mnist \
  -e MODEL_NAME=mnist \
  tensorflow/serving

# REST API 调用
curl -X POST http://localhost:8501/v1/models/mnist:predict \
  -H "Content-Type: application/json" \
  -d '{"instances": [[0.0, ...]]}'

4. TensorBoard 使用

tensorboard --logdir=logs --port=6006
# 添加更多可视化
file_writer = tf.summary.create_file_writer("logs/custom")

with file_writer.as_default():
    tf.summary.scalar("custom_loss", loss.numpy(), step=epoch)
    tf.summary.image("input_images", images, step=epoch, max_outputs=5)
    tf.summary.histogram("weights", model.layers[0].weights[0], step=epoch)

思考题

  1. TF 的静态图(Graph Mode)与动态图(Eager Mode)各自优劣势?
  2. tf.function 装饰器做了什么?它为什么能加速?
  3. TF Lite 量化有哪三种方式?各自的精度与速度权衡?
  4. 什么场景下应该选择 TF 而非 PyTorch?什么场景反过来?

信息

路径
/tech-stacks/tensorflow/tutorial/01-Keras与部署实战.md
更新时间
2026/5/30