跳到主要内容

L6-实战训练

本项目演示如何用 PyTorch + torchvision 实现 图像二分类任务(如猫狗分类、自定义两类数据集),并通过迁移学习快速训练出一个高准确率的模型。

为起到实战锻炼的效果,从本部分开始需要同学自己下载数据集、自己调优模型、并自行评测模型效果。


1. 数据准备

1.1 文件结构

准备一个目录 data/,并放置如下结构的图像数据(支持任意两类,示例为 Cat 与 Dog):

data/
├── train/
│   ├── Cat/
│   │   ├── xxx.jpg
│   │   └── yyy.jpg
│   └── Dog/
│       ├── aaa.jpg
│       └── bbb.jpg
└── val/
    ├── Cat/
    │   ├── ccc.jpg
    │   └── ddd.jpg
    └── Dog/
        ├── eee.jpg
        └── fff.jpg
  • train/:训练集(建议每类 ≥100 张图像)
  • val/:验证集(每类 ≥20 张图像)

1.2 数据来源

  • 可从 Kaggle Cats vs Dogs 数据集下载,手动划分 train/val/
  • 或者准备自己的两类图片(比如“口罩 vs 非口罩”、“有缺陷 vs 正常”)。

2. 环境配置

conda create -n resnet python=3.12 -y
conda activate resnet

# 安装 PyTorch 与 torchvision(根据你的 CUDA 版本替换)
pip install torch torchvision matplotlib scikit-learn

验证安装是否成功:

python -c "import torch; print(torch.__version__); print(torch.cuda.is_available())"

3. 训练代码

保存为 finetune_resnet.py

import argparse, time
import torch, torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import copy, os

def get_loaders(data_dir, batch_size=32):
    """Build train/val DataLoaders from an ImageFolder layout under *data_dir*.

    Expects ``data_dir/train/<class>/*.jpg`` and ``data_dir/val/<class>/*.jpg``.
    Returns ``(train_loader, val_loader, class_names)``.

    NOTE(review): pretrained ResNet weights were trained with ImageNet
    mean/std normalization; consider adding ``transforms.Normalize`` to both
    pipelines here AND in the inference script — confirm before changing.
    """
    # Light augmentation for training; deterministic resize-only for validation.
    augment = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
    ])
    plain = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    train_set = datasets.ImageFolder(os.path.join(data_dir, "train"), transform=augment)
    val_set = datasets.ImageFolder(os.path.join(data_dir, "val"), transform=plain)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
    return train_loader, val_loader, train_set.classes

def train(args):
    """Fine-tune the classification head of an ImageNet-pretrained ResNet-18.

    Only the final fully-connected layer is trained (the backbone is frozen).
    Tracks validation accuracy per epoch and saves the best weights, together
    with the class names, to ``best_resnet18.pt``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    train_dl, val_dl, classes = get_loaders(args.data_dir, args.batch_size)

    # Load the pretrained backbone and freeze every parameter.
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    for param in model.parameters():
        param.requires_grad = False
    # Replace the final classifier layer to match our class count.
    model.fc = nn.Linear(model.fc.in_features, len(classes))
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    # Only the new head's parameters are optimized.
    optimizer = torch.optim.Adam(model.fc.parameters(), lr=args.lr)

    best_weights = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(1, args.epochs + 1):
        # ---- training pass ----
        model.train()
        running_loss, n_correct, n_seen = 0.0, 0, 0
        started = time.time()
        for inputs, targets in train_dl:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            n_correct += (outputs.argmax(1) == targets).sum().item()
            n_seen += inputs.size(0)
        train_acc = n_correct / n_seen

        # ---- validation pass ----
        model.eval()
        val_loss, val_correct, val_seen = 0.0, 0, 0
        with torch.no_grad():
            for inputs, targets in val_dl:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item() * inputs.size(0)
                val_correct += (outputs.argmax(1) == targets).sum().item()
                val_seen += inputs.size(0)
        val_acc = val_correct / val_seen

        print(f"Epoch {epoch}/{args.epochs} | "
              f"train acc {train_acc:.3f} | val acc {val_acc:.3f} | "
              f"time {time.time()-started:.1f}s")

        # Keep a snapshot of the best weights by validation accuracy.
        if val_acc > best_acc:
            best_acc = val_acc
            best_weights = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_weights)
    torch.save({"state_dict": model.state_dict(), "classes": classes}, "best_resnet18.pt")
    print(f"保存最优模型:best_resnet18.pt(val acc={best_acc:.3f})")

if __name__ == "__main__":
    # CLI: only the data directory is required; the rest have sane defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=3e-4)
    train(parser.parse_args())

运行训练:

python finetune_resnet.py --data_dir data --epochs 5 --batch_size 32 --lr 3e-4

4. 推理代码

保存为 infer.py

import argparse, torch
from PIL import Image
from torchvision import transforms, models
import torch.nn as nn

def load_model(ckpt_path, labels=None):
    """Load a fine-tuned ResNet-18 checkpoint for inference.

    Parameters
    ----------
    ckpt_path : str
        Path to a checkpoint saved as ``{"state_dict": ..., "classes": [...]}``.
    labels : str or None
        Optional comma-separated class names. When given, they override the
        names stored in the checkpoint (or supply them if none were saved).

    Returns
    -------
    (model, classes)
        The model in eval mode and the list of class names.

    Raises
    ------
    ValueError
        If class names are neither stored in the checkpoint nor provided via
        *labels*. (The original code crashed here with ``len(None)``.)
    """
    ckpt = torch.load(ckpt_path, map_location="cpu")
    classes = ckpt.get("classes", None)
    if labels:
        # An explicit --labels argument wins over the stored names.
        classes = labels.split(",")
    if not classes:
        # BUG FIX: previously fell through to len(None) -> opaque TypeError.
        raise ValueError(
            'Class names not found in checkpoint; pass --labels, e.g. --labels "Cat,Dog"'
        )
    model = models.resnet18()  # architecture only; weights come from the checkpoint
    model.fc = nn.Linear(model.fc.in_features, len(classes))
    model.load_state_dict(ckpt["state_dict"])
    model.eval()  # freeze dropout/batch-norm behavior for inference
    return model, classes

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", required=True)
    parser.add_argument("--img", required=True)
    parser.add_argument("--labels", default=None, help="Comma separated class names if not saved in ckpt.")
    args = parser.parse_args()

    model, classes = load_model(args.ckpt, args.labels)
    # Same resize + tensor conversion the training script used.
    preprocess = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    image = Image.open(args.img).convert("RGB")
    batch = preprocess(image).unsqueeze(0)  # add batch dim: (1, 3, 224, 224)
    with torch.no_grad():
        prob = torch.softmax(model(batch), dim=1)[0]
    idx = prob.argmax().item()
    print(f"预测:{classes[idx]}(概率 {prob[idx].item():.3f})")

推理示例:

python infer.py --ckpt best_resnet18.pt --img data/val/Cat/cat001.jpg --labels "Cat,Dog"

5. 模型评估与可视化

训练完成后,可以进一步分析模型表现:

5.1 混淆矩阵

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import torch

# Requires val_dl, model and classes from the training session to be in scope.
device = "cuda" if torch.cuda.is_available() else "cpu"

model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for inputs, targets in val_dl:
        inputs, targets = inputs.to(device), targets.to(device)
        predictions = model(inputs).argmax(1)
        y_true.extend(targets.cpu().numpy())
        y_pred.extend(predictions.cpu().numpy())

# Row = true class, column = predicted class.
cm = confusion_matrix(y_true, y_pred, labels=range(len(classes)))
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes).plot(cmap="Blues")
plt.title("Confusion Matrix")
plt.show()

5.2 可视化预测结果

import matplotlib.pyplot as plt

# Show the first 8 validation images with true (T) and predicted (P) labels.
# Requires val_dl, model, classes and device from the training session.
model.eval()
images, labels = next(iter(val_dl))
with torch.no_grad():
    outputs = model(images.to(device))
preds = outputs.argmax(1).cpu()

plt.figure(figsize=(10, 5))
for i in range(8):
    plt.subplot(2, 4, i + 1)
    # Tensor is (C, H, W); matplotlib expects (H, W, C).
    plt.imshow(images[i].permute(1, 2, 0).numpy())
    # BUG FIX: the original split the f-string across two physical source
    # lines, which is a SyntaxError; use an explicit "\n" escape instead.
    plt.title(f"T:{classes[labels[i]]}\nP:{classes[preds[i]]}")
    plt.axis("off")
plt.tight_layout()
plt.show()

6. 进阶挑战

  • 解冻 ResNet 的最后几层,进行 微调(fine-tuning),提升精度。
  • 增加 学习率调度器(如 CosineAnnealingLR)。
  • 保存/加载 整个模型(不仅是 state_dict)。
  • 使用 torch.onnx.export 导出 ONNX 模型,方便部署。