2313 字
12 分钟
给神经网络加上保存和加载功能

训练了半天的模型,结果程序一关就全没了。这种事发生过一次之后,我就学乖了:模型必须能保存。

今天就来补上这一课:给之前手写的 MNIST 分类器加上保存、加载和推理功能。

虽然就是几行代码的事,但这是让模型从”玩具”变成”工具”的关键一步。

为什么需要保存模型#

训练一个模型需要时间。MNIST 虽然只要几秒钟,但复杂模型可能要几小时甚至几天。

如果每次使用都要重新训练,那就太浪费了。正确的做法是:

  1. 训练一次,保存参数
  2. 需要时加载,直接推理
  3. 更新时微调,从保存的参数继续训练

这就是模型保存和加载的意义。

NOTE

在深度学习框架里,保存模型通常叫 checkpoint 或 model serialization。PyTorch 用 torch.save(),TensorFlow 用 model.save()。我们今天用 numpy 的 np.savez() 来实现。

模型里有什么需要保存#

对于我们的两层神经网络,需要保存的就是参数:

第一层:

  • W1:权重矩阵 (784, 128)
  • b1:偏置向量 (1, 128)

第二层:

  • W2:权重矩阵 (128, 10)
  • b2:偏置向量 (1, 10)

就这四个数组。把它们保存到文件里,下次加载回来就能直接用。

TIP

不需要保存训练过程中的中间变量(比如 z1a1),那些都是前向传播时计算出来的。只需要保存训练好的参数。

实现保存功能#

np.savez() 把参数保存成 .npz 文件。

def save(self, path="mnist_model.npz"):
"""
保存模型参数
参数:
path: 保存路径
"""
np.savez(
path,
W1=self.W1, b1=self.b1,
W2=self.W2, b2=self.b2
)
print(f"Model saved to {path}")

np.savez() 会把多个数组打包成一个压缩文件。文件格式是二进制,比文本文件小得多。

使用示例#

# 训练完成后保存
model.save("mnist_model.npz")

输出:

Model saved to mnist_model.npz

文件大小大概几百 KB。相比训练数据(几十 MB),这个文件很小。

实现加载功能#

np.load() 把参数从文件里读回来。

def load(self, path="mnist_model.npz"):
"""
加载模型参数
参数:
path: 模型文件路径
"""
data = np.load(path)
self.W1 = data["W1"]
self.b1 = data["b1"]
self.W2 = data["W2"]
self.b2 = data["b2"]
print(f"Model loaded from {path}")

np.load() 返回一个字典,用参数名作为 key 就能取出对应的数组。

使用示例#

# 创建模型(参数随机初始化)
model = MinimalClassifier(input_dim=784, hidden_dim=128, output_dim=10, lr=0.01)
# 加载训练好的参数
model.load("mnist_model.npz")

输出:

Model loaded from mnist_model.npz

加载之后,模型的参数就变成了训练好的值,可以直接用来推理。

实现推理功能#

推理就是前向传播,但不需要反向传播。

def predict(self, X):
"""
推理(预测)
参数:
X: 输入图像,形状 (batch_size, 28, 28, 1)
返回:
预测类别,形状 (batch_size,)
"""
y_pred = self.forward(X)
return np.argmax(y_pred, axis=1)

forward() 返回的是概率分布(10 个类别的概率),np.argmax() 取概率最大的那个类别作为预测结果。

使用示例#

# 加载模型
model.load("mnist_model.npz")
# 推理
predictions = model.predict(X_test[:10])
print("预测结果:", predictions)

输出:

预测结果: [7 2 1 0 4 1 4 9 5 9]

完整的工作流程#

训练、保存、加载、推理的完整流程:

第一次运行:训练并保存#

import numpy as np
import tqdm
np.random.seed(0)
# 加载数据
data_dir = "./mnist"
X_train, y_train, X_test, y_test = load_mnist_from_local(data_dir)
# 初始化模型
model = MinimalClassifier(input_dim=784, hidden_dim=128, output_dim=10, lr=0.01)
# timeit
start_time = time.time()
# 训练
epochs = 5
batch_size = 64
loss_list = []
for epoch in range(epochs):
pbar = tqdm.tqdm(range(0, len(X_train), batch_size),
desc=f"Epoch {epoch+1}/{epochs}")
for i in pbar:
X_batch = X_train[i:i+batch_size]
y_batch = y_train[i:i+batch_size]
y_pred = model.forward(X_batch)
loss = cross_entropy(y_pred, y_batch)
loss_list.append(loss)
model.backward(y_batch)
model.step()
pbar.set_postfix({"loss": f"{loss:.4f}"})
# 测试
y_pred_test = model.forward(X_test)
pred_classes = np.argmax(y_pred_test, axis=1)
true_classes = np.argmax(y_test, axis=1)
acc = (pred_classes == true_classes).mean()
print(f"Test accuracy: {acc:.4f}")
# 保存模型
model.save("mnist_model.npz")
# timeit
end_time = time.time()
print(f"Time taken: {end_time - start_time:.2f} seconds")

输出:

Epoch 1/5: 100%| 938/938 [00:01<00:00, 473.26it/s, loss=0.9948]
Epoch 2/5: 100%| 938/938 [00:01<00:00, 474.05it/s, loss=0.3770]
Epoch 3/5: 100%| 938/938 [00:02<00:00, 433.10it/s, loss=0.2485]
Epoch 4/5: 100%| 938/938 [00:01<00:00, 505.55it/s, loss=0.1981]
Epoch 5/5: 100%| 938/938 [00:01<00:00, 663.14it/s, loss=0.1713]
Test accuracy: 0.9079
Model saved to mnist_model.npz
Time taken: 7.42 seconds

训练完成,模型保存到 mnist_model.npz

第二次运行:直接加载推理#

把训练代码注释掉,只保留加载和推理:

import numpy as np
import time
# some data structure are omitted
# 加载数据
data_dir = "./mnist"
X_train, y_train, X_test, y_test = load_mnist_from_local(data_dir)
# 创建模型(参数随机初始化)
model = MinimalClassifier(input_dim=784, hidden_dim=128, output_dim=10, lr=0.01)
time_start = time.time()
# 加载训练好的参数
model.load("mnist_model.npz")
# 推理
y_pred_test = model.forward(X_test)
pred_classes = np.argmax(y_pred_test, axis=1)
true_classes = np.argmax(y_test, axis=1)
acc = (pred_classes == true_classes).mean()
print(f"Test accuracy: {acc:.4f}")
end_time = time.time()
print(f"Time taken: {end_time - start_time:.2f} seconds")

输出:

Model loaded from mnist_model.npz
Test accuracy: 0.9079
Time taken: 0.03 seconds

准确率和训练时一样(0.9079),说明参数成功保存和加载了。而且这次运行只花了不到 0.1 秒,因为不需要训练,直接加载就能用。

IMPORTANT

加载模型时,网络结构(input_dim、hidden_dim、output_dim)必须和训练时一致。如果不一致,参数的形状会对不上,会报错。

保存和加载的时间对比#

让我们对比一下训练和加载的时间:

第一次运行(训练+保存):

  • 训练 5 个 epoch:约 8 秒
  • 保存模型:< 0.1 秒
  • 总时间:约 7.42 秒

第二次运行(加载+推理):

  • 加载模型:< 0.1 秒
  • 推理 10000 张测试图像:< 1 秒
  • 总时间:约 0.03 秒

快了 250 倍。对于更复杂的模型,差距会更大。 例如LLM的训练通常可能长达几个月,但是推断只需要几毫秒。

可视化推理结果#

随机挑几张测试集的图像,看看模型预测得对不对。

import matplotlib.pyplot as plt
import random
# 加载模型
model.load("mnist_model.npz")
# 随机挑 10 张图像
indices = random.sample(range(len(X_test)), 10)
plt.figure(figsize=(15, 3))
for i, idx in enumerate(indices):
# 推理
img = X_test[idx:idx+1]
pred_class = model.predict(img)[0]
true_class = np.argmax(y_test[idx])
# 可视化
plt.subplot(2, 5, i+1)
plt.imshow(img.squeeze(), cmap="gray")
# 标题:绿色表示正确,红色表示错误
color = "green" if pred_class == true_class else "red"
plt.title(f"Pred: {pred_class}\nTrue: {true_class}", color=color)
plt.axis("off")
plt.tight_layout()
plt.show()

大部分图像会显示绿色标题(预测正确),偶尔会有红色的(预测错误)。根据 90.79% 的准确率,平均 10 张图像里会有 1 张预测错误。

模型文件里有什么#

.npz 文件是一个压缩的 numpy 数组集合。可以用代码查看里面的内容:

import numpy as np
# 查看模型文件
data = np.load("mnist_model.npz")
print("模型包含的参数:")
for key in data.files:
print(f" {key}: shape={data[key].shape}, dtype={data[key].dtype}")
# 计算总参数量
total_params = sum(data[key].size for key in data.files)
print(f"\n总参数量: {total_params:,}")

输出:

模型包含的参数:
W1: shape=(784, 128), dtype=float64
b1: shape=(1, 128), dtype=float64
W2: shape=(128, 10), dtype=float64
b2: shape=(1, 10), dtype=float64
总参数量: 101,514

就这四个数组。总共 (784×128 + 128) + (128×10 + 10) = 101,514 个参数。

TIP

模型大小主要取决于参数数量。这个两层网络只有 10 万个参数,文件大概 800KB。真实的深度模型可能有几百万甚至几十亿个参数,文件会大得多(几 GB 甚至几十 GB)。

为什么不保存训练代码#

有人可能会问:为什么不把训练代码也保存下来?

因为:

  1. 训练代码和模型参数是分离的:代码定义了网络结构,参数是训练的结果。
  2. 代码通常是版本控制的:用 git 管理代码,用文件管理参数。
  3. 参数可以在不同代码版本间复用:只要网络结构不变,旧版本训练的参数可以在新版本代码中加载。

正确的做法是:

  • 代码:用 git 管理
  • 参数:用文件保存(.npz.pth.h5 等)
  • 训练配置:用配置文件(YAML、JSON)记录超参数(learning rate、batch size、epochs 等)

完整代码#

把保存、加载、推理功能整合到模型类里:

class MinimalClassifier:
def __init__(self, input_dim, hidden_dim, output_dim, lr=0.01):
self.lr = lr
self.W1 = np.random.randn(input_dim, hidden_dim) * 0.01
self.b1 = np.zeros((1, hidden_dim))
self.W2 = np.random.randn(hidden_dim, output_dim) * 0.01
self.b2 = np.zeros((1, output_dim))
def forward(self, X):
self.X = flatten(X)
self.z1 = self.X @ self.W1 + self.b1
self.a1 = relu(self.z1)
self.z2 = self.a1 @ self.W2 + self.b2
self.y_pred = softmax(self.z2)
return self.y_pred
def backward(self, y_true):
dL_dz2 = softmax_cross_entropy_grad(self.y_pred, y_true)
self.dW2 = self.a1.T @ dL_dz2
self.db2 = np.sum(dL_dz2, axis=0, keepdims=True)
dL_da1 = dL_dz2 @ self.W2.T
dL_dz1 = dL_da1 * relu_grad(self.z1)
self.dW1 = self.X.T @ dL_dz1
self.db1 = np.sum(dL_dz1, axis=0, keepdims=True)
def step(self):
self.W2 -= self.lr * self.dW2
self.b2 -= self.lr * self.db2
self.W1 -= self.lr * self.dW1
self.b1 -= self.lr * self.db1
# ========== 推理 ==========
def predict(self, X):
"""推理(预测类别)"""
y_pred = self.forward(X)
return np.argmax(y_pred, axis=1)
# ========== 保存 ==========
def save(self, path="mnist_model.npz"):
"""保存模型参数"""
np.savez(
path,
W1=self.W1, b1=self.b1,
W2=self.W2, b2=self.b2
)
print(f"Model saved to {path}")
# ========== 加载 ==========
def load(self, path="mnist_model.npz"):
"""加载模型参数"""
data = np.load(path)
self.W1 = data["W1"]
self.b1 = data["b1"]
self.W2 = data["W2"]
self.b2 = data["b2"]
print(f"Model loaded from {path}")

小结#

这就是模型保存、加载和推理的完整实现。虽然就是几行代码,但它让模型从”训练完就扔”变成了”训练一次,永久使用”。

核心要点:

  1. 保存:用 np.savez() 把参数打包成文件
  2. 加载:用 np.load() 把参数读回来
  3. 推理:前向传播 + argmax 得到预测类别

工作流程:

  1. 第一次运行:训练模型,保存参数
  2. 后续运行:加载参数,直接推理

时间对比:

  • 训练:约 7.42 秒
  • 加载+推理:约 0.03 秒(快 250 倍)

这就是为什么在生产环境里,我们总是先训练好模型,然后把参数部署到服务器上,而不是每次请求都重新训练。

[!quote] “Don’t repeat yourself.”
— The Pragmatic Programmer

训练一次,保存参数,多次使用。这是机器学习工程的基本原则。

上一篇:ML4-从零手搓一个 MNIST 分类器 下一篇:ML6-Adam 优化器:让训练快三倍的秘密


Works Cited

Goodfellow, Ian, et al. Deep Learning. MIT Press, 2016.

给神经网络加上保存和加载功能
https://blog.lishuyu.top/posts/ml5-给神经网络加上保存和加载功能/
作者
猫猫魔女
发布于
2025-11-20
许可协议
CC BY-NC-SA 4.0