Vision Transformer 是什麼?為何你需要了解它
Vision Transformer(ViT)是一種將 Transformer 架構應用於電腦視覺任務的深度學習模型。傳統 CNN(卷積神經網路)依賴局部特徵提取,而 ViT 將影像切分為多個小區塊(patches),將每個區塊視為一個「詞彙」,透過自注意力機制(Self-Attention)學習全域關聯。
ViT 的核心優勢在於:可擴展性強、訓練效率高、在大型資料集上表現優異。2020 年 Google 發表的論文「An Image is Worth 16x16 Words」正式開啟了 Transformer 進軍電腦視覺的時代。
ViT 架構解析:從輸入到輸出
ViT 的運作流程可分為以下步驟:
- 影像切塊(Patch Embedding):將輸入影像(例如 224×224)切割成固定大小的區塊(如 16×16),每個區塊展平為向量。
- 線性投射(Linear Projection):透過線性層將每個區塊向量映射到隱藏維度空間。
- 位置編碼(Position Embedding):加入位置資訊,讓模型理解各區塊的空間關係。
- CLS token:附加一個額外的分類 token,用於最終分類任務。
- Transformer Encoder:透過多層 Encoder 堆疊,計算自注意力並輸出特徵。
- 分類頭(Classification Head):取 CLS token 輸出,透過線性層進行分類。
如何使用 PyTorch 載入與預訓練 ViT
以下示範如何使用 timm 庫快速載入預訓練的 ViT 模型:
import timm
import torch
# 列出可用的 ViT 模型
models = timm.list_models('vit*')
print(models[:10])
# 載入預訓練的 ViT-B/16
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()
# 準備輸入影像
from PIL import Image
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 推論範例
img = Image.open('sample.jpg')
input_tensor = transform(img).unsqueeze(0)
with torch.no_grad():
output = model(input_tensor)
predicted_class = output.argmax(dim=1)
print(f"預測類別: {predicted_class.item()}")
微調 ViT 的完整步驟與實踐技巧
1. 資料準備與增強
微調前需準備自定義資料集,建議採用以下策略:
- 使用 ImageNet 預訓練權重作為初始化
- 針對小資料集使用較強的資料增強(MixUp、CutMix)
- 可採用微分學習率:淺層使用較低學習率,深層使用較高學習率
2. 微調訓練程式碼
import torch.optim as optim
from torch.utils.data import DataLoader
# 修改分類頭為你的類別數
num_classes = 10
model.head = torch.nn.Linear(model.head.in_features, num_classes)
# 訓練參數
learning_rate = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)
criterion = torch.nn.CrossEntropyLoss()
# 訓練迴圈
for epoch in range(10):
model.train()
for images, labels in train_loader:
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
3. 實用微調技巧
- 層級凍存(Layer Freezing):初期凍存淺層編碼器,僅訓練分類頭
- 學習率排程:使用 Warmup + Cosine Annealing 效果更佳
- 混合精度訓練:使用 torch.cuda.amp 加速訓練並節省記憶體
常見應用場景與模型比較
ViT 適用於多種電腦視覺任務:
- 影像分類:ImageNet Top-1 準確率可達 88.5%(ViT-H/14)
- 目標檢測:DETR、Swim Transformer 等延伸模型
- 語義分割:Swin Transformer 搭配 UperNet
- 影像生成:DiT(Diffusion Transformer)
相比 CNN,ViT 在大規模訓練時效能更佳,但在小資料集上可能需要更多資料增強或遷移學習技巧。