Vision Transformer 是什麼?為何你需要了解它

Vision Transformer(ViT)是一種將 Transformer 架構應用於電腦視覺任務的深度學習模型。傳統 CNN(卷積神經網路)依賴局部特徵提取,而 ViT 將影像切分為多個小區塊(patches),將每個區塊視為一個「詞彙」,透過自注意力機制(Self-Attention)學習全域關聯。

ViT 的核心優勢在於:可擴展性強、訓練效率高、在大型資料集上表現優異。2020 年 Google 發表的論文「An Image is Worth 16x16 Words」正式開啟了 Transformer 進軍電腦視覺的時代。

ViT 架構解析:從輸入到輸出

ViT 的運作流程可分為以下步驟:

  • 影像切塊(Patch Embedding):將輸入影像(例如 224×224)切割成固定大小的區塊(如 16×16),每個區塊展平為向量。
  • 線性投射(Linear Projection):透過線性層將每個區塊向量映射到隱藏維度空間。
  • 位置編碼(Position Embedding):加入位置資訊,讓模型理解各區塊的空間關係。
  • CLS token:附加一個額外的分類 token,用於最終分類任務。
  • Transformer Encoder:透過多層 Encoder 堆疊,計算自注意力並輸出特徵。
  • 分類頭(Classification Head):取 CLS token 輸出,透過線性層進行分類。

如何使用 PyTorch 載入與預訓練 ViT

以下示範如何使用 timm 庫快速載入預訓練的 ViT 模型:

import timm
import torch

# 列出可用的 ViT 模型
models = timm.list_models('vit*')
print(models[:10])

# 載入預訓練的 ViT-B/16
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()

# 準備輸入影像
from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

# 推論範例
img = Image.open('sample.jpg')
input_tensor = transform(img).unsqueeze(0)

with torch.no_grad():
    output = model(input_tensor)
    predicted_class = output.argmax(dim=1)
print(f"預測類別: {predicted_class.item()}")

微調 ViT 的完整步驟與實踐技巧

1. 資料準備與增強

微調前需準備自定義資料集,建議採用以下策略:

  • 使用 ImageNet 預訓練權重作為初始化
  • 針對小資料集使用較強的資料增強(MixUp、CutMix)
  • 可採用微分學習率:淺層使用較低學習率,深層使用較高學習率

2. 微調訓練程式碼

import torch.optim as optim
from torch.utils.data import DataLoader

# 修改分類頭為你的類別數
num_classes = 10
model.head = torch.nn.Linear(model.head.in_features, num_classes)

# 訓練參數
learning_rate = 1e-4
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)
criterion = torch.nn.CrossEntropyLoss()

# 訓練迴圈
for epoch in range(10):
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

3. 實用微調技巧

  • 層級凍存(Layer Freezing):初期凍存淺層編碼器,僅訓練分類頭
  • 學習率排程:使用 Warmup + Cosine Annealing 效果更佳
  • 混合精度訓練:使用 torch.cuda.amp 加速訓練並節省記憶體

常見應用場景與模型比較

ViT 適用於多種電腦視覺任務:

  • 影像分類:ImageNet Top-1 準確率可達 88.5%(ViT-H/14)
  • 目標檢測:DETR、Swim Transformer 等延伸模型
  • 語義分割:Swin Transformer 搭配 UperNet
  • 影像生成:DiT(Diffusion Transformer)

相比 CNN,ViT 在大規模訓練時效能更佳,但在小資料集上可能需要更多資料增強或遷移學習技巧。