Vision Transformer(ViT)是由 Google Research 团队于 2020 年提出的一种将 Transformer 架构直接应用于图像分类的模型架构。与传统的卷积神经网络(CNN)不同,ViT 将图像分割为固定大小的补丁(patch),将每个补丁视为一个"词符"(token),然后使用标准的 Transformer Encoder 进行特征提取和分类。
ViT 的研究论文 An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale(Dosovitskiy et al., 2020)首次证明了纯 Transformer 架构在图像识别任务上可以达到甚至超越最先进的 CNN,前提是在足够大规模的数据集上进行预训练。
在 ViT 出现之前,卷积神经网络(CNN)是计算机视觉领域的事实标准。CNN 通过卷积核在图像上滑动来提取局部特征,并通过堆叠卷积层和池化层来逐步扩大感受野。而 Transformer 在 NLP 领域的巨大成功(如 BERT、GPT 系列)引发了一个自然的问题:能否将 Transformer 这种基于自注意力机制的架构直接应用于视觉任务?
ViT 的关键洞见在于:图像可以被视为一个词符序列。将图像切分为 的补丁,每个补丁展开为一个向量,这些向量就构成了 Transformer 的输入序列。这种设计完全舍弃了 CNN 中的卷积和池化操作,仅依赖自注意力机制来捕获图像中不同区域之间的长距离依赖关系。
ViT 将输入图像 (高度 、宽度 、通道数 )分割为 个大小为 的补丁:
每个补丁被展平为一个 维的向量。例如,标准 ImageNet 图像为 ,补丁大小为 (即每张图像 196 个补丁),每个补丁展平后为 维向量。
这种补丁化操作可以视为一种极其粗粒度的"卷积"——步长等于补丁大小,且每个卷积核独立工作,不进行卷积核之间的融合运算。
ViT 的整体架构由以下几个关键组件组成:
Patch Embedding 通过一个线性投影层将每个展平的补丁向量映射到 维的嵌入空间。数学上表示为:
在实际实现中,由于展平操作破坏了空间结构,通常会使用一个 的卷积层(步长 、输出通道 )来实现 Patch Embedding。这一卷积层的权重矩阵等价于线性投影 的转置,但计算效率更高。
受 BERT 中 [CLS] 词符的启发,ViT 在输入序列的开头插入一个可学习的类别词符(class token)。该词符与图像补丁词符一起输入 Transformer Encoder,经过多层编码后,其对应位置的输出向量 被用于分类预测。
为什么需要类别词符?因为 Transformer 是序列到序列的模型,其输出序列长度与输入序列相同。为了从序列中提取全局分类信息,需要一个"汇总"词符来聚合来自所有图像补丁的信息。在训练过程中,类别词符通过自注意力机制与其他补丁交互,逐步学习到能够代表整张图像的全局特征表示。
由于 Transformer 的自注意力机制本身是置换等变(permutation equivariant)的——即改变输入顺序会同等改变输出顺序,但不改变输出值——模型本身无法感知词符之间的位置关系。对于图像而言,补丁之间的空间位置关系至关重要,必须引入位置编码来提供位置信息。
ViT 使用可学习的一维位置编码(learnable 1D positional encoding)。每个位置 (, 为补丁数量,包含类别词符)对应一个可学习的 维向量 。嵌入序列加上位置编码后作为 Transformer 的输入:
其中 为位置编码矩阵。
选择一维位置编码而非二维编码的原因有二:一是参数量更少;二是实验结果表明一维编码的效果与二维编码相当。这一简化也与 Transformer 的序列建模范式一致——图像补丁被线性化为一个序列,位置编码只需要标识每个补丁在序列中的序号。
ViT 使用的 Transformer Encoder 与 Vaswani et al. (2017) 原版 Transformer 中的编码器结构相同,由 层组成。每一层包含两个子层:
每个子层前使用层归一化(LayerNorm,LN),每个子层后使用残差连接(residual connection):
其中 表示第 层。
自注意力机制计算序列中每个词符与其他所有词符的注意力权重。对于第 层,输入 :
其中 为可学习的投影矩阵。
其中每个 , 为输出投影矩阵,。
MLP 层由两个全连接层和一个 GELU(Gaussian Error Linear Unit)激活函数组成:
其中 ,。通常 。
经过 层 Transformer Encoder 后,取类别词符位置的输出向量 ,通过一个全连接分类头进行类别预测:
其中 , 为类别总数。训练时使用交叉熵损失函数。
另一种替代方案是使用全局平均池化(global average pooling)聚合所有词符的输出,但实验表明类别词符的效果更好。
ViT 论文提供了三种不同规模的模型配置:
| 模型 | 层数 | 隐层维度 | 头数 | MLP 维度 | 参数量 |
|---|---|---|---|---|---|
| ViT-Base | 12 | 768 | 12 | 3072 | 86M |
| ViT-Large | 24 | 1024 | 16 | 4096 | 307M |
| ViT-Huge | 32 | 1280 | 16 | 5120 | 632M |
ViT-Base 的参数量与 BERT-Base 相同,ViT-Large 与 BERT-Large 相同,体现了架构层面的高度统一。在补丁大小方面,大多数实验使用 ,部分实验使用 或 。
ViT 的训练分为两个阶段:大规模预训练和下游任务微调。
预训练数据:
ViT 的核心发现之一是预训练数据规模至关重要。论文在以下数据集上进行了对比实验:
当在 ImageNet-1k(100 万张图像,1000 个类别)上直接预训练时,ViT 的性能低于同等规模的 CNN(如 ResNet)。但当在 JFT-300M 上预训练后,ViT 在多个下游任务上超越了最先进的 CNN 模型。这表明:Transformer 在大数据场景下的扩展性优于 CNN,其缺少的归纳偏置(如平移不变性、局部性)可以通过大规模数据学得。
预训练超参数:
在预训练后,ViT 在目标数据集上进行微调。微调时:
在 ImageNet-1k 验证集上的 Top-1 准确率(JFT-300M 预训练后微调):
| 模型 | Top-1 准确率 |
|---|---|
| ViT-Base/16 | 79.8% |
| ViT-Large/16 | 85.8% |
| ViT-Huge/14 | 88.6% |
| BiT-L (ResNet 变体) | 87.5% |
| Noisy Student (EfficientNet) | 87.4% |
ViT-Huge/14 在 ImageNet 上达到了当时最先进的 88.55% Top-1 准确率,超越了所有 CNN 模型。其中 "/14" 表示补丁大小为 。
ViT 在超过 20 个公开图像分类数据集上进行了迁移学习评估,涵盖了不同类型的数据集:
在大多数数据集上,在大规模数据上预训练的 ViT 达到了与 BiT-L(最优 CNN 基线)相当或更优的性能。特别是在数据量较大的下游数据集上,ViT 的优势更为明显。
ViT 在训练计算效率方面具有显著优势。相比 CNN(尤其是 EfficientNet),ViT 在达到相同性能水平时所需计算量更少:
这一效率优势来源于 Transformer 的高效矩阵运算和并行计算能力。CNN 需要大量的深度可分离卷积和通道混合操作,而 Transformer 的绝大部分计算集中在矩阵乘法上,更适配 TPU 等加速器。
ViT 的自注意力机制带来了天然的可解释性。论文中展示了类别词符的多头注意力图,揭示了模型关注的不同区域:
有趣的是,ViT 即使在较低的层中也显示出了对图像结构的理解——一些注意力头学会了关注整张图像的轮廓,另一些头则关注内部纹理。这种自动学到的长距离依赖关系是 CNN 在相当深(需要很多层)后才能具备的能力。
此外,论文还发现:
ViT 与 CNN 最根本的区别在于归纳偏置的差异:
CNN 的内置归纳偏置:
ViT 的弱归纳偏置:
ViT 的设计哲学是:模型越灵活,对数据的依赖越强,但数据量足够时性能上限也越高。
| 维度 | CNN(ResNet/EfficientNet) | ViT |
|---|---|---|
| 小数据 (<10M 图像) | ✅ 更强(归纳偏置有效) | ❌ 弱(需要数据学习) |
| 大数据 (>100M 图像) | 性能受限(扩展性瓶颈) | ✅ 更强(扩展性优越) |
| 计算效率 | 中等 | ✅ 更高(2-4x) |
| 长距离依赖 | ❌ 需要深层堆叠 | ✅ 天然具备 |
| 可解释性 | 需要 Grad-CAM 等工具 | ✅ 注意力权重天然可解释 |
| 推理速度 | ✅ 更快(优化成熟) | 中等等待 |
| 参数量效率 | ✅ 更高 | 需要更多参数/数据 |
| 硬件适配 | 通用 | ✅ TPU 高效 |
ViT 论文中最重要的发现之一是 Transformer 在视觉任务上的优越扩展性。随着预训练数据的增加,ViT 的性能呈对数线性增长,且增长速度远快于 CNN。当数据量足够时,ViT 的性能曲线尚未饱和,而 CNN 已趋于平稳。这一发现预示了后续大规模视觉模型的发展方向。
ViT 扩展性的理论解释:Transformer 的自注意力机制允许模型灵活地学习数据中的任意模式,而 CNN 的固定卷积核结构对表达能力施加了约束。在大数据场景下,约束成为瓶颈;在小数据场景下,约束成为有效的正则化项。
DeiT(Touvron et al., 2021)解决了 ViT 需要巨量预训练数据的问题。核心创新包括:
DeiT 在 ImageNet 上仅用 300 个 epoch(无外部数据)就达到了 83.1%(DeiT-B)的 Top-1 准确率,而相同设置下 ViT-Base 仅约 78%。
Swin Transformer(Liu et al., 2021)引入了层次化特征金字塔和移动窗口注意力(shifted window attention):
Swin Transformer 在 ImageNet 上达到 87.3% Top-1(Swin-B),在 COCO 目标检测和 ADE20K 语义分割上大幅超越之前的 SOTA。
| 变体 | 核心创新 | 代表性论文 |
|---|---|---|
| CaiT | LayerScale(各层独立缩放因子);Class Attention(将自注意力分离为自注意力与跨注意力两阶段) | Touvron et al., 2021 |
| TNT | 内外层 Transformer:内层处理像素级细节,外层处理补丁间关系 | Han et al., 2021 |
| PVT | 金字塔结构,多阶段下采样,适配密集预测任务 | Wang et al., 2021 |
| CrossViT | 双分辨率路径(大补丁/小补丁),跨注意力融合 | Chen et al., 2021 |
| LeViT | 混合 CNN-Transformer 架构,为快速推理优化 | Graham et al., 2021 |
| MaxViT | 多轴注意力(窗口+全局)+ CNN,在全任务上高效 | Tu et al., 2022 |
| ConvNeXt | 纯 CNN 架构,借鉴 Transformer 设计原则,达到 Swin 性能 | Liu et al., 2022 |
| EfficientViT | 级联分组注意力,为高分辨率推理优化 | Cai et al., 2023 |
在 ViT 论文中,作者也探索了混合架构:
这一发现说明:混合架构在数据不足时可以提供有益的归纳偏置,但在数据充足时,纯 Transformer 架构已经能够通过数据学到 CNN 天然具备的视觉处理能力。
ViT 最直接的应用是图像分类。作为一种通用的视觉架构,ViT 在所有主流分类基准上均达到了顶尖水平。实际部署时,需要根据数据量和计算资源选择合适的变体:
ViT 的架构可以迁移到更广泛的视觉任务:
在生产环境部署 ViT 时需考虑:
nn.Linear + nn.LayerNorm 融合)加速推理import torch
import torch.nn as nn
class PatchEmbed(nn.Module):
"""图像补丁嵌入层"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# 使用卷积实现补丁化 + 线性投影
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size and W == self.img_size
x = self.proj(x) # [B, embed_dim, H/patch, W/patch]
x = x.flatten(2) # [B, embed_dim, num_patches]
x = x.transpose(1, 2) # [B, num_patches, embed_dim]
return x
class MultiHeadAttention(nn.Module):
"""多头自注意力"""
def __init__(self, embed_dim, num_heads):
super().__init__()
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(embed_dim, embed_dim * 3)
self.proj = nn.Linear(embed_dim, embed_dim)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class MLP(nn.Module):
"""两层 MLP 使用 GELU 激活"""
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_dim, out_dim)
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
class TransformerBlock(nn.Module):
"""单层 Transformer Encoder"""
def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = MultiHeadAttention(embed_dim, num_heads)
self.drop1 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(embed_dim)
self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), embed_dim)
self.drop2 = nn.Dropout(dropout)
def forward(self, x):
x = x + self.drop1(self.attn(self.norm1(x)))
x = x + self.drop2(self.mlp(self.norm2(x)))
return x
class VisionTransformer(nn.Module):
"""完整的 ViT 模型"""
def __init__(self, img_size=224, patch_size=16, in_chans=3,
num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4.0, dropout=0.1):
super().__init__()
# Patch Embedding
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
# 类别词符 + 位置编码
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(dropout)
# Transformer Encoder 堆叠
self.blocks = nn.ModuleList([
TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
for _ in range(depth)
])
self.norm = nn.LayerNorm(embed_dim)
# 分类头
self.head = nn.Linear(embed_dim, num_classes)
# 初始化
nn.init.trunc_normal_(self.cls_token, std=0.02)
nn.init.trunc_normal_(self.pos_embed, std=0.02)
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x) # [B, N, D]
# 插入类别词符
cls_token = self.cls_token.expand(B, -1, -1)
x = torch.cat([cls_token, x], dim=1) # [B, N+1, D]
# 添加位置编码
x = x + self.pos_embed
x = self.pos_drop(x)
# Transformer Encoder
for block in self.blocks:
x = block(x)
x = self.norm(x)
# 取类别词符用于分类
x = x[:, 0]
x = self.head(x)
return x
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import requests
# 加载预训练模型和处理器
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name)
# 准备图像
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# 预处理
inputs = processor(images=image, return_tensors="pt")
# 推理
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted_class_id = logits.argmax(-1).item()
# 打印预测结果
print(f"预测类别: {model.config.id2label[predicted_class_id]}")
# 使用 timm 库中的 DeiT 实现
import timm
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
# 创建 DeiT 模型(使用蒸馏)
model = timm.create_model('deit_base_distilled_patch16_224', pretrained=False)
teacher_model = timm.create_model('regnety_160', pretrained=True)
# 冻结教师模型
for param in teacher_model.parameters():
param.requires_grad = False
# 优化器
optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.05)
scaler = GradScaler()
# 蒸馏损失函数
def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temp=3.0):
"""
alpha: 真实标签与蒸馏标签的权重平衡系数
temp: 蒸馏温度
"""
ce_loss = nn.CrossEntropyLoss()(student_logits, labels)
# KL 散度蒸馏损失
soft_student = nn.functional.log_softmax(student_logits / temp, dim=-1)
soft_teacher = nn.functional.softmax(teacher_logits / temp, dim=-1)
kd_loss = nn.KLDivLoss(reduction='batchmean')(soft_student, soft_teacher) * (temp ** 2)
return (1 - alpha) * ce_loss + alpha * kd_loss
# 训练循环
for images, labels in dataloader:
optimizer.zero_grad()
with autocast():
# 前向传播
student_output = model(images)
with torch.no_grad():
teacher_output = teacher_model(images)
loss = distillation_loss(student_output, teacher_output, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
ViT 的提出是计算机视觉领域的范式转变事件。在此之前,CNN 占据绝对主导地位;之后,基于 Transformer 的视觉模型成为主流方向。具体影响包括:
ViT 的核心贡献在于证明了一个简单的直觉:当数据足够时,纯 Transformer 架构可以直接处理像素级别的图像输入,无需卷积层的特殊设计。这一发现不仅带来了性能上的突破,更重要的是推动了计算机视觉与自然语言处理的架构统一,为多模态大模型时代的到来奠定了基础。
ViT 的关键启示:
从历史视角看,ViT 是计算机视觉从"手工设计特征"到"学习特征"再到"通用架构统一"这一漫长演进中的关键里程碑。它证明了在不同模态之间共享架构的可能性,这一思想正在深刻影响着整个深度学习领域的发展方向。
此页面为 AI 知识体系 的一部分,内容持续更新中。