前言

模型训练出来只是第一步,让它稳定高效地服务线上请求才是真正的 challenge。这篇文章覆盖模型部署的完整流程——从模型导出到生产级推理服务。

Step 1:模型导出与格式转换

PyTorch → ONNX

ONNX(Open Neural Network Exchange)是模型部署的中间格式,几乎所有的推理框架和硬件都支持它。

import torch
import torch.onnx

model = YourModel()
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
    opset_version=17,
)

关键参数:
- dynamic_axes:声明哪些维度是可变的,比如 batch_size 和序列长度
- opset_version:越高支持的算子越多,但硬件兼容性可能下降。推荐 17-19

ONNX 优化

导出的 ONNX 模型可以用 onnxruntime 做优化:

python -m onnxruntime.tools.convert_onnx_models_to_ort model.onnx

或者在代码中优化:

import onnxruntime as ort
import onnx
from onnxruntime.transformers import optimizer

opt = optimizer.optimize_model(
    "model.onnx",
    model_type="bert",    # 根据模型类型选择
    num_heads=12,
    hidden_size=768,
    opt_level=99,         # 最大优化
)
opt.save_model_to_file("model_optimized.onnx")

优化效果:一般能提升 1.5x-3x 的推理速度。

检查 ONNX 模型的正确性

# 验证输出一致性
with torch.no_grad():
    torch_output = model(dummy_input).numpy()

ort_session = ort.InferenceSession("model.onnx")
ort_input = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
ort_output = ort_session.run(None, ort_input)

# 检查误差
diff = np.abs(torch_output - ort_output[0]).max()
print(f"Max diff: {diff}")
# 对于 FP32,diff 1 时需多卡
)

Step 4:性能压测与监控

关键指标

指标 英文 含义 好值
首 Token 延迟 TTFT 从请求到第一个 Token 的时间 1000 tokens/s
并发数 Concurrency 同时处理的请求数 视场景而定

压测工具

# 用 vllm 自带的 benchmark
python -m vllm.benchmarks.benchmark_throughput     --model meta-llama/Llama-3.2-3B     --dataset ShareGPT_V3_unfiltered_cleaned_split.json     --num-prompts 1000

# 用 locust 做 HTTP 压测
pip install locust
locust --host http://localhost:8000

生产级部署架构

                   ┌─────────────┐
                   │   Load      │
                   │  Balancer   │
                   └──────┬──────┘
                          │
          ┌───────────────┼───────────────┐
          │               │               │
    ┌─────▼─────┐   ┌─────▼─────┐   ┌─────▼─────┐
    │ Worker 1  │   │ Worker 2  │   │ Worker 3  │
    │ vLLM/Triton│   │ vLLM/Triton│   │ vLLM/Triton│
    └───────────┘   └───────────┘   └───────────┘
          │               │               │
    ┌─────▼───────────────▼───────────────▼─────┐
    │          共享显存 / 模型分片                 │
    └───────────────────────────────────────────┘
  • 水平扩展:多个 Worker 加 Load Balancer
  • 动态缩扩容:根据队列长度自动调整 Worker 数量
  • 缓存层:Redis 缓存常见问题的回答(适合 LLM 场景,回答可以复用)

总结

模型部署的核心链路:

PyTorch → ONNX → TensorRT → Triton/vLLM → Load Balancer → API
  ①        ②        ③          ④               ⑤
  1. PyTorch 导出 ONNX(设置 dynamic_axes)
  2. ONNX 优化(onnxruntime transformers)
  3. TensorRT 转换(FP16 是最优性价比)
  4. 推理框架部署(LLM 用 vLLM,多模态用 Triton)
  5. 负载均衡与缓存

每个环节都有优化空间,但先跑通全链路,再逐步压测优化是最务实的方式。