TensorFlow 教程 —— Keras 进阶与模型部署
本章目标
- 掌握 Keras Functional API 和子类化 API
- 了解 TF 与 PyTorch 的核心差异
- 掌握 TF Lite 与 TF Serving 部署
1. 三种建模 API 对比
Sequential(最简单)
model = tf.keras.Sequential([
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10, activation="softmax"),
])
# 适用:单输入 → 单输出,线性堆叠
Functional API(中等灵活)
inputs = tf.keras.Input(shape=(784,))
x = tf.keras.layers.Dense(128, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(10, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)
# 适用:多输入/输出、共享层、残差连接、分支合并
Subclassing(最灵活)
class MyModel(tf.keras.Model):
def __init__(self):
super().__init__()
self.dense1 = tf.keras.layers.Dense(128, activation="relu")
self.dense2 = tf.keras.layers.Dense(10, activation="softmax")
def call(self, inputs, training=False):
x = self.dense1(inputs)
if training:
x = tf.keras.layers.Dropout(0.5)(x)
return self.dense2(x)
# 适用:需要动态控制流、自定义训练逻辑
# 缺点:无法自动 summary、无法直接 save/load 结构
2. TF vs PyTorch 核心差异
| 维度 |
TensorFlow |
PyTorch |
| 高级 API |
Keras(官方) |
无官方;Lightning 社区 |
| 计算图 |
静态(TF 1.x)→ Eager(TF 2.x) |
始终动态 |
| 部署 |
TF Serving / TF Lite / TF.js |
TorchServe / ONNX |
| 移动端 |
TF Lite(最成熟) |
PyTorch Mobile |
| 浏览器 |
TF.js |
有限 |
| 生产 Pipeline |
TFX(完整) |
需拼凑 |
| 学术采用 |
较少 |
占主导 |
| 行业采用 |
较多(Google 生态) |
快速增长 |
3. 模型部署
3.1 TensorFlow Lite(移动端/边缘端)
# 转换模型
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT] # 量化优化
tflite_model = converter.convert()
# 保存
with open("model.tflite", "wb") as f:
f.write(tflite_model)
# 在 Android/iOS 上加载推理(需 TF Lite 运行时)
# 详细见: https://www.tensorflow.org/lite
3.2 TensorFlow Serving(服务器部署)
model.save("mnist_model/1/", save_format="tf")
docker run -p 8501:8501 \
--mount type=bind,source=$(pwd)/mnist_model,target=/models/mnist \
-e MODEL_NAME=mnist \
tensorflow/serving
curl -X POST http:
-H "Content-Type: application/json" \
-d '{"instances": [[0.0, ...]]}'
4. TensorBoard 使用
tensorboard --logdir=logs --port=6006
# 添加更多可视化
file_writer = tf.summary.create_file_writer("logs/custom")
with file_writer.as_default():
tf.summary.scalar("custom_loss", loss.numpy(), step=epoch)
tf.summary.image("input_images", images, step=epoch, max_outputs=5)
tf.summary.histogram("weights", model.layers[0].weights[0], step=epoch)
思考题
- TF 的静态图(Graph Mode)与动态图(Eager Mode)各自优劣势?
tf.function 装饰器做了什么?它为什么能加速?
- TF Lite 量化有哪三种方式?各自的精度与速度权衡?
- 什么场景下应该选择 TF 而非 PyTorch?什么场景反过来?