AI 开发框架是构建、训练、部署和推理人工智能模型的软件基础设施。自 2015 年 TensorFlow 开源以来,框架生态经历了从学术研究工具到工业级生产平台的演进。本索引提供主流框架的概览、选型对比和使用指南。
| 分类 |
代表框架 |
主要用途 |
| 深度学习训练 |
PyTorch, TensorFlow, JAX |
模型定义、自动微分、分布式训练 |
| 推理优化 |
TensorRT, ONNX Runtime, vLLM |
模型压缩、量化、加速推理 |
| LLM 服务 |
vLLM, TGI, llama.cpp |
大语言模型的高吞吐部署 |
| 分布式训练 |
DeepSpeed, Megatron-LM, FSDP |
多卡/多机并行训练 |
| 生态平台 |
Hugging Face Transformers |
预训练模型加载、微调、分享 |
| 自动机器学习 |
AutoML, Optuna, Ray Tune |
超参数搜索、架构搜索 |
| 世代 |
时期 |
代表 |
核心特征 |
| 第一代 |
2013—2015 |
Theano, Caffe |
静态图、学术研究工具 |
| 第二代 |
2015—2018 |
TensorFlow 1.x, CNTK, MXNet |
工业级部署、静态图+Session |
| 第三代 |
2018—2022 |
PyTorch, TensorFlow 2.x |
动态图优先、Eager Execution、eager debugging |
| 第四代 |
2022—至今 |
JAX, PyTorch 2.x, Mojo |
函数式编程、JIT 编译、GPU 编程语言 |
PyTorch 由 Facebook AI Research(现 Meta AI)于 2016 年发布,采用动态计算图(Define-by-Run)设计,允许在运行时动态构建计算图,极大降低了调试复杂度。
| 组件 |
功能 |
关键 API |
| Tensor |
核心数据结构(类似 NumPy ndarray,但支持 GPU) |
torch.tensor(), torch.randn() |
| Autograd |
自动微分引擎,记录计算图并通过链式法则求导 |
tensor.backward(), torch.autograd |
| nn.Module |
神经网络模块化基类 |
torch.nn.Module, torch.nn.Linear |
| Optim |
优化器实现 |
torch.optim.SGD, torch.optim.Adam |
| DataLoader |
数据加载与批处理 |
torch.utils.data.DataLoader |
| TorchScript |
模型序列化与部署格式 |
torch.jit.script(), torch.jit.trace() |
1. 定义模型类(继承 nn.Module)
→ 在 __init__ 中声明层
→ 在 forward 中定义前向传播逻辑
2. 加载数据
→ 自定义 Dataset 类
→ DataLoader 封装(batch_size, shuffle)
3. 配置训练
→ 损失函数:CrossEntropyLoss, MSELoss
→ 优化器:Adam(lr=0.001)
→ 学习率调度:StepLR, CosineAnnealing
4. 训练循环(每个 epoch)
→ 前向:outputs = model(inputs)
→ 损失:loss = criterion(outputs, labels)
→ 反向:loss.backward()
→ 参数更新:optimizer.step()
→ 梯度清零:optimizer.zero_grad()
5. 评估与保存
→ torch.no_grad() 推理
→ torch.save(model.state_dict(), "model.pth")
| PyTorch 版本 |
发布时间 |
关键特性 |
| 1.0 |
2018-12 |
稳定版发布,TorchScript 支持 |
| 1.7 |
2020-10 |
CUDA 11 支持,自定义 C++ 算子 |
| 1.9 |
2021-06 |
TorchDataloader 性能优化 |
| 1.11 |
2022-03 |
Functorch(函数式 API)、GPU 训练优化 |
| 1.13 |
2022-10 |
BetterTransformer、TorchAudio 更新 |
| 2.0 |
2023-03 |
torch.compile (JIT 编译)、加速 30%—200% |
| 2.1 |
2023-10 |
torch.distributed 增强、CPU 推理优化 |
| 2.5 |
2024-10 |
FlexAttention、torch.export 稳定版 |
| 模型 |
Eager 模式 |
torch.compile |
加速比 |
| ResNet-50 |
320 img/s |
550 img/s |
1.72× |
| ViT-B/16 |
180 img/s |
350 img/s |
1.94× |
| GPT-2 (1.5B) |
120 tok/s |
220 tok/s |
1.83× |
| LLaMA-7B |
35 tok/s |
58 tok/s |
1.66× |
| Stable Diffusion |
4.2 it/s |
7.8 it/s |
1.86× |
数据来源:PyTorch 2.0 官方 benchmark(A100-80G GPU),batch_size=32。
TensorFlow 于 2015 年由 Google Brain 开源,经历了从静态图(1.x)到动态图优先(2.x)的重大转型。
| 版本 |
核心机制 |
优缺点 |
| 1.x |
静态图(Graph)+ Session |
部署高效但调试困难 |
| 2.x |
Eager Execution(动态图) |
调试友好,Keras 为高级 API |
TensorFlow 2.x 以 Keras 作为官方高级 API,提供三种模型构建方式:
| 方式 |
适用场景 |
示例 |
| Sequential |
简单顺序模型 |
tf.keras.Sequential([Dense(64), Dense(10)]) |
| Functional API |
多输入/多输出模型 |
tf.keras.Model(inputs=x, outputs=z) |
| Subclassing |
自定义复杂模型 |
继承 tf.keras.Model 重写 call |
| 组件 |
用途 |
替代方案 |
| TFX |
生产级 ML 流水线 |
Kubeflow |
| TF Serving |
模型部署服务器 |
Triton, TorchServe |
| TF Lite |
移动端/嵌入式部署 |
Core ML, ONNX Runtime |
| TF.js |
浏览器端推理 |
ONNX.js, WebDNN |
| TPU |
专用硬件加速 |
GPU, Habana Gaudi |
JAX 是 Google 于 2018 年推出的函数式深度学习框架,以其强大的自动微分和 XLA(Accelerated Linear Algebra)JIT 编译能力著称。
| 特性 |
说明 |
| 函数式编程 |
纯函数 + 不可变数组(无 in-place 操作) |
| 自动微分 |
jax.grad(), jax.jacobian(), jax.hessian() |
| JIT 编译 |
jax.jit() 将 Python 函数编译为 XLA HLO |
| 自动向量化 |
jax.vmap() 自动批量化 |
| 并行化 |
jax.pmap() 多设备并行 |
| 反向模式自动微分 |
支持 grad、value_and_grad、grad 链 |
| 维度 |
JAX |
PyTorch |
| 编程范式 |
函数式(纯函数) |
面向对象(nn.Module) |
| 计算图 |
函数变换(jit/grad/vmap) |
Autograd 自动记录 |
| 随机数 |
显式 PRNG key 状态 |
隐式全局种子 |
| 控制流 |
JIT 内需要 jax.lax.cond/while_loop |
Python 原生 if/for |
| 社区生态 |
Flax, Haiku, Optax |
Hugging Face, Lightning |
| 易用性 |
陡峭学习曲线 |
直观易用 |
Hugging Face 是目前最大的预训练模型生态平台,Transformers 库支持 PyTorch、TensorFlow 和 JAX 三后端。
| 能力 |
支持的模型数 |
典型 API |
| 文本分类 |
300+ |
pipeline('sentiment-analysis') |
| 文本生成 |
200+ |
pipeline('text-generation') |
| 翻译 |
150+ |
pipeline('translation_en_to_zh') |
| 摘要 |
100+ |
pipeline('summarization') |
| 问答 |
200+ |
pipeline('question-answering') |
| 图像分类 |
100+ |
pipeline('image-classification') |
| 语音识别 |
80+ |
pipeline('automatic-speech-recognition') |
| 指标 |
2022 |
2023 |
2024 |
| 模型总数 |
150K |
400K |
800K+ |
| 数据集数 |
20K |
50K |
100K+ |
| GitHub Stars |
80K |
110K |
140K+ |
| 月活跃开发者 |
500K |
1.2M |
2M+ |
| 后端 |
占比 |
典型用例 |
| PyTorch |
85% |
研究、微调、训练 |
| TensorFlow |
10% |
已有 TF 基础设施的公司 |
| JAX/Flax |
5% |
Google 团队、高性能需求 |
数据来源:Hugging Face 2024 年度报告。
DeepSpeed 由微软于 2020 年开源,是训练大规模模型的核心工具。
| 技术 |
原理 |
内存节省 |
| ZeRO Stage 1 |
优化器状态分片 |
4× |
| ZeRO Stage 2 |
梯度分片 |
8× |
| ZeRO Stage 3 |
参数分片 |
64× |
| ZeRO-Offload |
CPU/NVMe 卸载 |
无限(受限于 CPU 内存) |
| ZeRO-Infinity |
结合 offload + 带宽优化 |
可在单 GPU 训练 万亿参数模型 |
ZeRO Stage 3 内存节省示例:
假设训练 175B GPT-3 模型(约 350GB 参数),混合精度(Adam 状态约 700GB + 梯度 350GB + 参数 350GB = 约 1.4TB):
| 配置 |
GPU 内存需求 |
所需 GPU 数(80GB A100) |
| 无 ZeRO |
~1,400 GB |
18 |
| ZeRO Stage 1 |
~1,050 GB |
14 |
| ZeRO Stage 2 |
~700 GB |
9 |
| ZeRO Stage 3 |
~350 GB |
5 |
| ZeRO Stage 3 + Offload |
~50 GB |
1 |
Megatron-LM 专注于张量并行和序列并行:
| 并行策略 |
概念 |
扩展效果 |
| 张量并行(TP) |
将单个 Transformer 层切分到多 GPU |
减少层内显存,降低延迟 |
| 流水线并行(PP) |
将不同层分配到不同 GPU |
支持更深的模型 |
| 序列并行(SP) |
将序列维度切分 |
支持超长上下文训练 |
3D 并行(TP + PP + DP)效果:
| 模型大小 |
GPU 数 |
3D 并行配置 |
吞吐量(TFLOPs/GPU) |
| GPT-3 175B |
1024 A100 |
TP=8, PP=8, DP=16 |
145 TFLOPs |
| LLaMA-2 70B |
64 A100 |
TP=4, PP=4, DP=4 |
168 TFLOPs |
| BLOOM 176B |
384 A100 |
TP=4, PP=8, DP=12 |
150 TFLOPs |
数据来源:NVIDIA Megatron-LM 论文及开源报告。
vLLM 是专为大语言模型推理优化的框架,核心创新为 PagedAttention:
| 技术 |
解决的问题 |
效果 |
| PagedAttention |
KV Cache 内存碎片 |
内存利用率 95%+ |
| 连续批处理 |
请求排队等待 |
吞吐量 2—4× 提升 |
| 量化支持 |
模型压缩 |
FP16→INT4 显存减 75% |
吞吐量对比(LLaMA-7B, A100-80G):
| 框架 |
请求/秒 |
相对于 Hugging Face |
| Hugging Face |
12 |
1× |
| vLLM |
48 |
4× |
| TGI |
36 |
3× |
| llama.cpp (4-bit) |
28 |
2.3× |
数据来源:vLLM GitHub README(2024 年 benchmark)。
TensorRT-LLM 是 NVIDIA 的 LLM 推理引擎,集成多种优化技术:
| 优化技术 |
技术说明 |
加速比 |
| FP8 量化 |
8-bit 浮点推理 |
2× vs FP16 |
| INT4 AWQ |
4-bit 权重量化 |
4× vs FP16 |
| In-flight Batching |
动态批处理 |
1.5—2× |
| 多节点推理 |
Tensor 并行 + 流水线并行 |
线性扩展 |
| 页面注意力 |
类似 PagedAttention |
1.5—2× |
| 需求场景 |
推荐框架 |
理由 |
| CV 研究快速实验 |
PyTorch |
动态图、丰富社区、Easy Debug |
| NLP/LLM 研究 |
PyTorch + HF Transformers |
Hugging Face 生态最完善 |
| 工业级部署 |
TensorFlow + TF Serving |
成熟部署方案、TPU 支持 |
| 高吞吐 LLM 服务 |
vLLM + TensorRT-LLM |
PagedAttention、连续批处理 |
| 万亿参数训练 |
DeepSpeed + Megatron-LM |
ZeRO、3D 并行 |
| 函数式/学术实验 |
JAX + Flax |
高度灵活、XLA 加速 |
| 移动端部署 |
TensorFlow Lite |
最广泛的移动端支持 |
| HPC / 科学计算 |
JAX |
自动微分 + 高性能计算 |
| 生成式 AI 应用 |
PyTorch + Diffusers |
Stable Diffusion 生态 |
初学者(1—3 个月)
├── PyTorch + Hugging Face Transformers
│ └── 目标:理解训练循环、模型加载、微调
└── TensorFlow + Keras
└── 目标:理解生产部署流程
进阶(3—6 个月)
├── JAX + Flax
│ └── 目标:函数式编程、JIT 编译、自动向量化
├── DeepSpeed / FSDP
│ └── 目标:分布式训练
└── vLLM / TensorRT-LLM
└── 目标:推理优化
高级(6—12 个月)
├── 自定义 CUDA 算子、Triton 语言
├── 模型并行策略设计(TP/PP/SP/EP)
└── 训练框架底层优化(通信压缩、内存管理)
| 框架 |
吞吐量 (img/s) |
峰值显存 (GB) |
代码行数 |
| PyTorch |
1,820 |
8.2 |
~80 |
| TensorFlow 2.x |
1,690 |
7.9 |
~95 |
| JAX |
2,150 |
6.8 |
~60 |
| Keras |
1,650 |
8.0 |
~50 |
性能数据来自 MLPerf Training v3.0 公开结果。
- 统一训练 + 推理框架:PyTorch 2.x 的 torch.compile 和 torch.export 正在模糊训练与推理的界限
- JavaScript AI 框架崛起:Transformers.js、ONNX Runtime Web 使浏览器端推理成为可能
- 模型压缩集成:量化和蒸馏已成为框架标配,而非事后优化
- 多模态原生支持:框架原生 API 支持文本+图像+语音多模态输入
- 编译导向:从解释执行转向 JIT/AOT 编译(torch.compile, JAX, TVM)
- 异构计算:CPU + GPU + NPU + TPU 统一编程模型
| 年份 |
上升趋势 |
下降趋势 |
| 2023 |
PyTorch 2.0, JAX |
TensorFlow 1.x |
| 2024 |
vLLM, TensorRT-LLM, Mojo |
Caffe, Theano |
| 2025 |
编译型框架, Mojo, MLX (Apple) |
纯 Python 训练循环 |
此页面为 AI 知识体系 的一部分,内容持续更新中。