Beginner-level PyTorch code like this is all over the internet, and GPT can generate it perfectly well, but I'm writing it down anyway: it wraps a simple Trainer that might be worth reusing if I end up writing small neural networks again later.

mytrainer.py
```python
import copy
import logging

import torch


class MyTrainer:
    def __init__(self, model, loss_fn, optimizer, device=None, logger=None):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = device or torch.device("cpu")
        self.model.to(self.device)
        self.best_model_state = None

        if logger is None:
            self.logger = logging.getLogger(__name__)
            if not self.logger.handlers:
                handler = logging.StreamHandler()
                formatter = logging.Formatter(
                    "[%(asctime)s][%(levelname)s] %(message)s", "%Y-%m-%d %H:%M:%S"
                )
                handler.setFormatter(formatter)
                self.logger.addHandler(handler)
            self.logger.setLevel(logging.INFO)
        else:
            self.logger = logger

    def fit(
        self, train_loader, val_loader=None, epochs=20, patience=None, log_interval=1
    ):
        best_val_loss = float("inf")
        epochs_no_improve = 0
        train_losses, val_losses = [], []

        for epoch in range(epochs):
            self.model.train()
            batch_train_losses = []
            for x_batch, y_batch in train_loader:
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                self.optimizer.zero_grad()
                outputs = self.model(x_batch)
                loss = self.loss_fn(outputs, y_batch)
                loss.backward()
                self.optimizer.step()
                batch_train_losses.append(loss.item())
            train_loss = sum(batch_train_losses) / len(batch_train_losses)
            train_losses.append(train_loss)

            val_loss = None  # stays None when no validation loader is given
            if val_loader is not None:
                self.model.eval()
                batch_val_losses = []
                with torch.no_grad():
                    for x_batch, y_batch in val_loader:
                        x_batch, y_batch = x_batch.to(self.device), y_batch.to(
                            self.device
                        )
                        outputs = self.model(x_batch)
                        loss = self.loss_fn(outputs, y_batch)
                        batch_val_losses.append(loss.item())
                val_loss = sum(batch_val_losses) / len(batch_val_losses)
                val_losses.append(val_loss)

                if val_loss < best_val_loss:  # Early stopping bookkeeping
                    best_val_loss = val_loss
                    # Deep-copy: state_dict() returns references to the live
                    # tensors, which later optimizer steps would overwrite.
                    self.best_model_state = copy.deepcopy(self.model.state_dict())
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    if patience is not None and epochs_no_improve >= patience:
                        self.logger.info(f"Early stopping at epoch {epoch + 1}")
                        self.model.load_state_dict(self.best_model_state)
                        break
            else:
                val_losses.append(None)

            if (epoch + 1) % log_interval == 0:
                msg = f"Epoch {epoch + 1}/{epochs} - Train Loss: {train_loss:.4f}"
                if val_loss is not None:
                    msg += f", Val Loss: {val_loss:.4f}"
                self.logger.info(msg)

        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)

        self.train_losses = train_losses
        self.val_losses = val_losses
        return train_losses, val_losses
```
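The trainer only keeps the best weights in memory. If you want them to survive the process, a minimal sketch of saving and restoring them (the file name `best_model.pt` is just a placeholder):

```python
# Save the restored-best weights after fit(); the path is an arbitrary choice.
torch.save(trainer.model.state_dict(), "best_model.pt")

# Restore later into a freshly built model of the same architecture.
model.load_state_dict(torch.load("best_model.pt", map_location=device))
```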

Linear Regression

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

from mytrainer import MyTrainer


# ===========================
# y = 2 * x + 3 + noise
# ===========================
class LinearDataset(Dataset):
    def __init__(self, n_samples=200):
        super().__init__()
        self.x = torch.rand(n_samples, 1)
        self.y = 2 * self.x + 3 + 0.1 * torch.randn(n_samples, 1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


train_dataset = LinearDataset(200)
val_dataset = LinearDataset(50)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)


class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LinearModel().to(device)

loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

trainer = MyTrainer(model, loss_fn, optimizer, device)
train_losses, val_losses = trainer.fit(train_loader, val_loader, epochs=10, patience=3)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(train_losses, label="Train Loss", marker="o")
ax.plot(val_losses, label="Validation Loss", marker="x")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Linear Regression Loss")
ax.legend()
ax.grid(True)
fig.savefig("linear_loss.png", dpi=300)
plt.show()

w = model.linear.weight.item()
b = model.linear.bias.item()
print(f"Trained weight: {w:.4f}, bias: {b:.4f}")
```

MNIST, MLP

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

from mytrainer import MyTrainer

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

train_dataset = datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
val_dataset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)


class SimpleMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.net(x)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleMLP().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

trainer = MyTrainer(model, loss_fn, optimizer, device)
train_losses, val_losses = trainer.fit(train_loader, val_loader, epochs=10, patience=3)

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(train_losses, label="Train Loss", marker="o")
ax.plot(val_losses, label="Validation Loss", marker="x")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss")
ax.legend()
ax.grid(True)
fig.savefig("loss_curve.png", dpi=300)
plt.show()
```
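The loss curve alone doesn't say how well the classifier actually does. A small evaluation sketch, continuing the script above, that computes validation accuracy (my addition, not part of the original training flow):

```python
# Accuracy on the validation set: argmax over the 10 logits gives the digit.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for x_batch, y_batch in val_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        preds = model(x_batch).argmax(dim=1)
        correct += (preds == y_batch).sum().item()
        total += y_batch.size(0)
print(f"Validation accuracy: {correct / total:.4f}")
```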