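After all those fully connected networks, it is finally time for a CNN.
Writing a complete CNN by hand in pure numpy, with convolution layers, pooling layers, and backpropagation, is a fairly painful job. Worse, Python's nested loops are hopelessly slow: training one epoch took 40 minutes.
With Numba JIT compilation the code ran more than ten times faster, cutting training time from 40 minutes to 3 minutes. Accuracy also rose from 96.24% (fully connected) to 98.06%.
It was exhausting, but seeing test accuracy break 98% was still quite satisfying.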
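Why we need a CNN
A fully connected network can reach 96% on MNIST, but it has one serious flaw: it completely ignores the spatial structure of the image.
Flattening a 28×28 image into a 784-dimensional vector means:
- The relationship between neighboring pixels is lost
- The model is very sensitive to transformations such as rotation and translation
- The parameter count is large (784×128 = 100,352 weights in the first layer alone)
A CNN addresses these problems with convolution and pooling:
- Convolution: extracts local features (edges, textures, shapes)
- Pooling: lowers the resolution and adds some tolerance to small translations
- Weight sharing: the same kernel slides over the whole image, which cuts the parameter count sharply
NOTE: LeNet (1998) is one of the earliest CNNs and already reached over 99% accuracy on MNIST. What we implement today is a simplified version of LeNet.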
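Network architecture
The CNN we are going to implement:

    Input (28×28×1)
      ↓
    Conv2D (8 filters, 3×3 kernel) → (26×26×8)
      ↓
    ReLU
      ↓
    Conv2D (16 filters, 3×3 kernel) → (24×24×16)
      ↓
    ReLU
      ↓
    MaxPool (2×2) → (12×12×16)
      ↓
    Flatten → (2304,)
      ↓
    Dense (10 neurons) → (10,)
      ↓
    Softmax

Layer notes:
- Conv1: 1 → 8 channels, extracts 8 basic features
- Conv2: 8 → 16 channels, extracts 16 higher-level features
- MaxPool: 2×2 downsampling, reduces computation
- Dense: fully connected layer, outputs probabilities for the 10 classes
Compared with the fully connected network (784 → 128 → 10), this CNN has more layers but actually fewer parameters.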
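The shapes in the diagram follow from two rules: a stride-1 convolution without padding shrinks each spatial side by k − 1, and 2×2 pooling halves it. A minimal sketch of that arithmetic; the helper names `conv_out` and `pool_out` are mine for illustration and are not part of the post's code:

```python
def conv_out(shape, filters, k):
    """Output shape of a stride-1, no-padding convolution on an (H, W, C) input."""
    h, w, _ = shape
    return (h - k + 1, w - k + 1, filters)


def pool_out(shape, size=2):
    """Output shape of non-overlapping size x size max pooling."""
    h, w, c = shape
    return (h // size, w // size, c)


shape = (28, 28, 1)
trace = [("Input", shape)]

shape = conv_out(shape, filters=8, k=3)
trace.append(("Conv1", shape))

shape = conv_out(shape, filters=16, k=3)
trace.append(("Conv2", shape))

shape = pool_out(shape)
trace.append(("MaxPool", shape))

flat = shape[0] * shape[1] * shape[2]
trace.append(("Flatten", (flat,)))

for name, s in trace:
    print(f"{name:8s} {s}")
```

The flattened size, 12 × 12 × 16 = 2304, is the input width the dense layer has to accept.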
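The performance disaster of Python loops
Before implementing the CNN, let's see why optimization is needed at all.
A convolution is, at its core, a deeply nested loop:

    import numpy as np

    def conv2d_naive(X, W, b):
        B, H, W_in, C_in = X.shape
        out_c, _, k, _ = W.shape
        out_h = H - k + 1
        out_w = W_in - k + 1
        out = np.zeros((B, out_h, out_w, out_c))

        # The batch index is n, not b, so it does not shadow the bias argument
        for n in range(B):                              # batch
            for oc in range(out_c):                     # output channel
                for i in range(out_h):                  # height
                    for j in range(out_w):              # width
                        for ic in range(C_in):          # input channel
                            for ki in range(k):         # kernel height
                                for kj in range(k):     # kernel width
                                    out[n, i, j, oc] += X[n, i+ki, j+kj, ic] * W[oc, ic, ki, kj]
                        out[n, i, j, oc] += b[oc, 0]    # add the bias once per output value
        return out

That is 7 nested loops. For one MNIST batch (64 images) going through the first convolution layer: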
- B = 64
- out_c = 8
- out_h = 26
- out_w = 26
- C_in = 1
- k = 3
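Total innermost iterations: 64 × 8 × 26 × 26 × 1 × 3 × 3 ≈ 3.1 million.
Each epoch has 938 batches, and every batch needs both a forward and a backward pass, so the total number of iterations is astronomical.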
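To put numbers on "astronomical", the count above can be extended to the second convolution layer and to a whole epoch. This is a rough sketch that assumes the standard 60,000-image MNIST training set and counts forward passes only; `conv_macs` is a name I made up for the illustration:

```python
import math


def conv_macs(batch, out_c, out_h, out_w, in_c, k):
    """Innermost-loop iterations (multiply-adds) for one forward pass of a conv layer."""
    return batch * out_c * out_h * out_w * in_c * k * k


conv1 = conv_macs(batch=64, out_c=8, out_h=26, out_w=26, in_c=1, k=3)
conv2 = conv_macs(batch=64, out_c=16, out_h=24, out_w=24, in_c=8, k=3)

# Assumption: 60,000 training images, batch size 64 (the last batch is smaller,
# so the per-epoch figure below is a slight overestimate)
batches_per_epoch = math.ceil(60000 / 64)
forward_per_epoch = (conv1 + conv2) * batches_per_epoch

print(f"conv1 per batch:   {conv1:,}")
print(f"conv2 per batch:   {conv2:,}")
print(f"batches per epoch: {batches_per_epoch}")
print(f"forward per epoch: {forward_per_epoch:,}")
```

The second layer dominates: with 8 input channels and 16 output channels it needs roughly 14 times as many multiply-adds as the first, and the forward passes alone come to tens of billions of interpreted loop iterations per epoch.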
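Measured results:
- Pure Python loops: about 40 minutes to train one epoch
- With Numba JIT compilation: about 3 minutes (the training log later in this post shows roughly 3 minutes for all 3 epochs)
That is a gap of at least 13×.
WARNING: Python loops are slow because the language is interpreted: every iteration pays for type checks, reference counting, and similar overhead. In deeply nested loops this overhead adds up to an unacceptable level.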
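What is Numba
Numba is a JIT (Just-In-Time) compiler that turns Python code into machine code.
The core idea:
- On the first call, the Python function is compiled to machine code
- Later calls run the machine code directly and bypass the Python interpreter
- For numeric computation and loops, the speedup can be tens of times
Usage:

    from numba import njit

    @njit(cache=True, fastmath=True)
    def my_function(x):
        # Your numeric code goes here; this placeholder sums a 1-D array
        result = 0.0
        for i in range(x.shape[0]):
            result += x[i]
        return result

The @njit decorator compiles the function to machine code. cache=True caches the compiled result, so the next run does not have to recompile. fastmath=True allows floating-point operations that are faster but slightly less precise.
TIP: Numba is a good fit for nested loops, numeric computation, and numpy array operations. It is a poor fit for Python objects such as dicts, lists, and strings.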
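The "compile on first call" behavior is easy to observe by timing two calls of the same function. The sketch below is illustrative only: `sum_of_squares` is a made-up example, and the `try/except` shim is my own assumption so that the snippet still runs (without any speedup) on a machine where Numba is not installed.

```python
import time

import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to a no-op decorator so the sketch runs without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit
def sum_of_squares(x):
    total = 0.0
    for i in range(x.shape[0]):
        total += x[i] * x[i]
    return total


x = np.arange(100_000, dtype=np.float64) / 100_000

t0 = time.perf_counter()
first = sum_of_squares(x)    # with Numba: this call includes compilation
t1 = time.perf_counter()
second = sum_of_squares(x)   # with Numba: this call reuses the compiled code
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.4f} s")
print(f"second call: {t2 - t1:.4f} s")
print("matches numpy:", np.isclose(first, np.sum(x * x)))
```

With Numba installed, the first call is typically much slower than the second, which is why benchmarks should always exclude a warm-up call.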
Implementing the convolution layer (with Numba acceleration)
Forward pass

    import numpy as np
    from numba import njit

    @njit(cache=True, fastmath=True)
    def conv_forward_jit(X, W, b, k):
        """
        Convolution forward pass (Numba-accelerated).
        Args:
            X: input, shape (B, H, W, C_in)
            W: kernels, shape (out_c, C_in, k, k)
            b: bias, shape (out_c, 1)
            k: kernel size
        Returns:
            out: output, shape (B, out_h, out_w, out_c)
        """
        B, H, W_in, C = X.shape
        out_c = W.shape[0]
        out_h = H - k + 1
        out_w = W_in - k + 1
        out = np.zeros((B, out_h, out_w, out_c), dtype=X.dtype)

        for b_idx in range(B):
            for oc in range(out_c):
                for i in range(out_h):
                    for j in range(out_w):
                        acc = 0.0
                        for ic in range(C):
                            for ki in range(k):
                                for kj in range(k):
                                    acc += X[b_idx, i + ki, j + kj, ic] * W[oc, ic, ki, kj]
                        out[b_idx, i, j, oc] = acc + b[oc, 0]
        return out

This code is almost identical to the naive version. The main change is the @njit decorator (plus a local accumulator acc and an explicit kernel-size argument k).
The performance gap, however, is enormous:
- Naive Python: 40 minutes/epoch
- Numba JIT: 3 minutes/epoch
Backward pass

    @njit(cache=True, fastmath=True)
    def conv_backward_jit(X, grad, W, k):
        """
        Convolution backward pass (Numba-accelerated).
        Args:
            X: input, shape (B, H, W, C_in)
            grad: output gradient, shape (B, out_h, out_w, out_c)
            W: kernels, shape (out_c, C_in, k, k)
            k: kernel size
        Returns:
            dW: kernel gradient
            db: bias gradient
            dX: input gradient
        """
        B, out_h, out_w, out_c = grad.shape
        in_c = X.shape[3]

        dW = np.zeros_like(W)
        db = np.zeros((out_c, 1), dtype=grad.dtype)
        dX = np.zeros_like(X)

        for b_idx in range(B):
            for oc in range(out_c):
                for ic in range(in_c):
                    for i in range(out_h):
                        for j in range(out_w):
                            g = grad[b_idx, i, j, oc]
                            if g == 0.0:
                                continue  # nothing to propagate for this output value
                            for ki in range(k):
                                for kj in range(k):
                                    dW[oc, ic, ki, kj] += g * X[b_idx, i + ki, j + kj, ic]
                                    dX[b_idx, i + ki, j + kj, ic] += g * W[oc, ic, ki, kj]
                # Bias gradient: accumulated once per (sample, output channel), outside the ic loop
                db[oc, 0] += np.sum(grad[b_idx, :, :, oc])
        return dW, db, dX

The backward loops are more involved: the nesting is just as deep, and every innermost iteration updates both dW and dX. This is where the Numba speedup matters most.
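Hand-written backward passes are easy to get subtly wrong, so it helps to compare them against finite differences on a tiny input. The sketch below is not the post's Numba code: it re-implements the same forward and backward math with `np.einsum` (my own choice, to keep the example short and Numba-free) and checks all three gradients numerically. The function names and the test sizes are assumptions for illustration.

```python
import numpy as np


def conv_forward(X, W, b):
    """Stride-1, no-padding convolution. X: (B, H, W, C), W: (out_c, C, k, k), b: (out_c, 1)."""
    B, H, W_in, C = X.shape
    out_c, _, k, _ = W.shape
    out = np.zeros((B, H - k + 1, W_in - k + 1, out_c))
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            patch = X[:, i:i + k, j:j + k, :]                     # (B, k, k, C)
            out[:, i, j, :] = np.einsum("bhwc,ochw->bo", patch, W) + b[:, 0]
    return out


def conv_backward(X, grad, W):
    """Gradients of sum(out * grad) with respect to W, b and X."""
    k = W.shape[2]
    dW = np.zeros_like(W)
    dX = np.zeros_like(X)
    for i in range(grad.shape[1]):
        for j in range(grad.shape[2]):
            patch = X[:, i:i + k, j:j + k, :]
            g = grad[:, i, j, :]                                  # (B, out_c)
            dW += np.einsum("bo,bhwc->ochw", g, patch)
            dX[:, i:i + k, j:j + k, :] += np.einsum("bo,ochw->bhwc", g, W)
    db = grad.sum(axis=(0, 1, 2)).reshape(-1, 1)
    return dW, db, dX


def numeric_grad(f, A, eps=1e-6):
    """Central-difference gradient of the scalar f() with respect to array A (perturbed in place)."""
    num = np.zeros_like(A)
    for idx in np.ndindex(*A.shape):
        old = A[idx]
        A[idx] = old + eps
        plus = f()
        A[idx] = old - eps
        minus = f()
        A[idx] = old
        num[idx] = (plus - minus) / (2 * eps)
    return num


rng = np.random.default_rng(0)
X = rng.normal(size=(2, 5, 5, 2))
W = rng.normal(size=(3, 2, 3, 3))
b = rng.normal(size=(3, 1))
R = rng.normal(size=(2, 3, 3, 3))    # stand-in for the upstream gradient


def loss():
    return np.sum(conv_forward(X, W, b) * R)


dW, db, dX = conv_backward(X, R, W)
err_W = np.max(np.abs(dW - numeric_grad(loss, W)))
err_b = np.max(np.abs(db - numeric_grad(loss, b)))
err_X = np.max(np.abs(dX - numeric_grad(loss, X)))
print(f"max |dW - numeric| = {err_W:.2e}")
print(f"max |db - numeric| = {err_b:.2e}")
print(f"max |dX - numeric| = {err_X:.2e}")
```

The same `numeric_grad` helper can be pointed at the Numba version: if the three errors are not tiny, the analytic gradient and the forward pass disagree somewhere.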
Wrapping it in a class

    class Conv2D:
        def __init__(self, in_channels, out_channels, kernel_size):
            self.in_c = in_channels
            self.out_c = out_channels
            self.k = kernel_size
            self.W = np.random.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.1
            self.b = np.zeros((out_channels, 1))

        def forward(self, X):
            self.X = X  # cache the input for the backward pass
            self.out = conv_forward_jit(X, self.W, self.b, self.k)
            return self.out

        def backward(self, grad):
            self.dW, self.db, self.dX = conv_backward_jit(self.X, grad, self.W, self.k)
            return self.dX

        def step(self, lr):
            self.W -= lr * self.dW
            self.b -= lr * self.db

Usage example:

    conv = Conv2D(in_channels=1, out_channels=8, kernel_size=3)
    out = conv.forward(X)   # forward pass
    conv.backward(grad)     # backward pass
    conv.step(lr=0.01)      # parameter update

Implementing the pooling layer (with Numba acceleration)
Forward pass

    @njit(cache=True, fastmath=True)
    def maxpool_forward_jit(X):
        """
        2×2 max pooling forward pass (Numba-accelerated).
        Args:
            X: input, shape (B, H, W, C)
        Returns:
            out: output, shape (B, H//2, W//2, C)
            argmax: index of the maximum inside each window (0..3)
        """
        B, H, W, C = X.shape
        out = np.zeros((B, H // 2, W // 2, C), dtype=X.dtype)
        argmax = np.zeros((B, H // 2, W // 2, C), dtype=np.int64)

        for b_idx in range(B):
            for c in range(C):
                # Stop at H - 1 / W - 1 so an odd trailing row or column is
                # skipped instead of being read out of bounds
                for i in range(0, H - 1, 2):
                    for j in range(0, W - 1, 2):
                        # Find the maximum inside the 2×2 window
                        max_val = X[b_idx, i, j, c]
                        max_idx = 0

                        idx = 1
                        v = X[b_idx, i, j + 1, c]
                        if v > max_val:
                            max_val = v
                            max_idx = idx

                        idx += 1
                        v = X[b_idx, i + 1, j, c]
                        if v > max_val:
                            max_val = v
                            max_idx = idx

                        idx += 1
                        v = X[b_idx, i + 1, j + 1, c]
                        if v > max_val:
                            max_val = v
                            max_idx = idx

                        out[b_idx, i // 2, j // 2, c] = max_val
                        argmax[b_idx, i // 2, j // 2, c] = max_idx
        return out, argmax

Pooling has to record where each maximum came from (argmax), because the backward pass needs it.
Backward pass

    @njit(cache=True, fastmath=True)
    def maxpool_backward_jit(grad, argmax, H_in, W_in):
        """
        2×2 max pooling backward pass (Numba-accelerated).
        Args:
            grad: output gradient, shape (B, H//2, W//2, C)
            argmax: maximum positions recorded during the forward pass
            H_in, W_in: height and width of the input
        Returns:
            dX: input gradient, shape (B, H_in, W_in, C)
        """
        B, H2, W2, C = grad.shape
        dX = np.zeros((B, H_in, W_in, C), dtype=grad.dtype)

        for b_idx in range(B):
            for c in range(C):
                for i in range(H2):
                    for j in range(W2):
                        idx = argmax[b_idx, i, j, c]
                        bi = idx // 2        # row offset inside the window
                        bj = idx - bi * 2    # column offset inside the window
                        dX[b_idx, i * 2 + bi, j * 2 + bj, c] = grad[b_idx, i, j, c]
        return dX

The gradient flows back only to the position of the maximum; every other position gets zero.
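A tiny worked example makes the argmax bookkeeping concrete. The sketch below is a simplification of the code above, assuming a single 2-D image with no batch or channel axes; the function names are made up for the illustration. It uses the same index encoding, `idx = 2 * row_offset + col_offset`.

```python
import numpy as np


def maxpool2x2(x):
    """2x2 max pooling on a single 2-D array; returns the output and window-local argmax."""
    H, W = x.shape
    out = np.zeros((H // 2, W // 2), dtype=x.dtype)
    argmax = np.zeros((H // 2, W // 2), dtype=np.int64)
    for i in range(H // 2):
        for j in range(W // 2):
            window = x[2 * i:2 * i + 2, 2 * j:2 * j + 2]
            argmax[i, j] = np.argmax(window)   # flat index 0..3 = 2 * row + col
            out[i, j] = window.max()
    return out, argmax


def maxpool2x2_backward(grad, argmax, shape):
    """Route each output gradient back to the position that produced the maximum."""
    dx = np.zeros(shape, dtype=grad.dtype)
    for i in range(grad.shape[0]):
        for j in range(grad.shape[1]):
            di, dj = divmod(int(argmax[i, j]), 2)
            dx[2 * i + di, 2 * j + dj] = grad[i, j]
    return dx


x = np.array([[1.0, 9.0, 2.0, 3.0],
              [4.0, 5.0, 8.0, 6.0],
              [7.0, 1.0, 0.0, 2.0],
              [3.0, 2.0, 4.0, 1.0]])
grad = np.array([[10.0, 20.0],
                 [30.0, 40.0]])

out, argmax = maxpool2x2(x)
dx = maxpool2x2_backward(grad, argmax, x.shape)

print(out)      # pooled maxima
print(argmax)   # which of the 4 window positions held each maximum
print(dx)       # gradient placed only at those positions
```

Each of the four gradient values lands on exactly one input pixel, so `dx` has four nonzero entries and sums to the same total as `grad`.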
Wrapping it in a class

    class MaxPool2x2:
        def forward(self, X):
            self.X = X
            out, self.argmax = maxpool_forward_jit(X)
            return out

        def backward(self, grad):
            B, H, W, C = self.X.shape
            return maxpool_backward_jit(grad, self.argmax, H, W)

The complete CNN model
Putting all the layers together (Dense, TinyAdam, relu, softmax and the other helpers are defined in the complete listing at the end of this post):

    class MinimalCNNClassifier:
        def __init__(self, lr=0.01, use_adam=False):
            self.lr = lr
            self.use_adam = use_adam
            self.adam = TinyAdam(lr=lr) if use_adam else None

            # Layers
            self.conv1 = Conv2D(1, 8, 3)         # 1 → 8 channels
            self.conv2 = Conv2D(8, 16, 3)        # 8 → 16 channels
            self.pool = MaxPool2x2()             # 2×2 pooling
            self.fc = Dense(12 * 12 * 16, 10)    # fully connected layer

        def forward(self, X):
            # Conv1 + ReLU
            out = self.conv1.forward(X)
            out = relu(out)
            self.after_relu1 = out

            # Conv2 + ReLU
            out = self.conv2.forward(out)
            out = relu(out)
            self.after_relu2 = out

            # MaxPool
            out = self.pool.forward(out)

            # Flatten + Dense + Softmax
            out = flatten(out)
            out = self.fc.forward(out)
            self.y_pred = softmax(out)
            return self.y_pred

        def backward(self, y_true):
            # Gradient at the output layer
            grad = softmax_cross_entropy_grad(self.y_pred, y_true)

            # Dense backward
            grad = self.fc.backward(grad)

            # Unflatten
            grad = grad.reshape(-1, 12, 12, 16)

            # MaxPool backward
            grad = self.pool.backward(grad)

            # Conv2 backward
            grad = grad * relu_grad(self.after_relu2)
            grad = self.conv2.backward(grad)

            # Conv1 backward
            grad = grad * relu_grad(self.after_relu1)
            self.conv1.backward(grad)

        def step(self):
            if self.adam:
                self.adam.step([
                    (self.fc.W, self.fc.dW),
                    (self.fc.b, self.fc.db),
                    (self.conv2.W, self.conv2.dW),
                    (self.conv2.b, self.conv2.db),
                    (self.conv1.W, self.conv1.dW),
                    (self.conv1.b, self.conv1.db),
                ])
            else:
                self.fc.step(self.lr)
                self.conv2.step(self.lr)
                self.conv1.step(self.lr)

IMPORTANT: Backpropagation visits the layers in the reverse order of the forward pass. The ReLU gradient has to be applied (as an element-wise product) before calling the backward pass of the convolution that feeds it.
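The backward pass above starts from `softmax_cross_entropy_grad`, which returns `(y_pred - y_true) / batch_size`. That closed form can be checked numerically. The sketch below is a standalone check with made-up random logits; `cross_entropy_from_logits` is a helper I introduce for the test and, unlike the post's `cross_entropy`, it does no clipping.

```python
import numpy as np


def softmax(x):
    x = x - np.max(x, axis=1, keepdims=True)   # subtract the row max for numerical stability
    ex = np.exp(x)
    return ex / np.sum(ex, axis=1, keepdims=True)


def cross_entropy_from_logits(logits, y_true):
    """Mean cross-entropy over the batch, computed from raw logits."""
    return np.mean(-np.sum(y_true * np.log(softmax(logits)), axis=1))


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))
y_true = np.eye(10)[rng.integers(0, 10, size=4)]   # one-hot labels

# Closed-form gradient of the mean loss with respect to the logits
analytic = (softmax(logits) - y_true) / logits.shape[0]

# Central-difference gradient, one logit at a time
eps = 1e-5
numeric = np.zeros_like(logits)
for idx in np.ndindex(*logits.shape):
    plus = logits.copy()
    minus = logits.copy()
    plus[idx] += eps
    minus[idx] -= eps
    numeric[idx] = (cross_entropy_from_logits(plus, y_true)
                    - cross_entropy_from_logits(minus, y_true)) / (2 * eps)

max_err = np.max(np.abs(analytic - numeric))
print(f"max |analytic - numeric| = {max_err:.2e}")
```

The division by the batch size comes from the mean in the loss; dropping it would silently scale the effective learning rate by the batch size.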
Training and testing

    import numpy as np
    import tqdm

    np.random.seed(0)

    # Load the data
    data_dir = "./mnist"
    X_train, y_train, X_test, y_test = load_mnist_from_local(data_dir)

    # Create the model
    model = MinimalCNNClassifier(lr=0.001, use_adam=True)

    # Train
    def train(model, X_train, y_train, epochs, batch_size):
        loss_list = []
        for epoch in range(epochs):
            pbar = tqdm.tqdm(
                range(0, len(X_train), batch_size),
                desc=f"Epoch {epoch+1}/{epochs}"
            )
            for i in pbar:
                X_batch = X_train[i:i+batch_size]
                y_batch = y_train[i:i+batch_size]

                y_pred = model.forward(X_batch)
                loss = cross_entropy(y_pred, y_batch)
                loss_list.append(loss)

                model.backward(y_batch)
                model.step()

                pbar.set_postfix({"loss": f"{loss:.4f}"})
        return loss_list

    loss_list = train(model, X_train, y_train, epochs=3, batch_size=64)

    # Test
    def test(model, X_test, y_test):
        y_pred_test = model.forward(X_test)
        pred = np.argmax(y_pred_test, axis=1)
        true = np.argmax(y_test, axis=1)
        acc = (pred == true).mean()
        print(f"\nTest accuracy: {acc:.4f}")
        return acc

    acc = test(model, X_test, y_test)

Training results
Output from an actual run:

    Epoch 1/3: 100%| 938/938 [00:54<00:00, 17.24it/s, loss=0.0286]
    Epoch 2/3: 100%| 938/938 [00:54<00:00, 17.19it/s, loss=0.0267]
    Epoch 3/3: 100%| 938/938 [00:49<00:00, 18.90it/s, loss=0.0153]

    Test accuracy: 0.9806

Final results:
- Training time: about 3 minutes (3 epochs)
- Test accuracy: 98.06%
- Final loss: 0.0153 (the loss on the last training batch)
Compared with the earlier results:
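| Model | Accuracy | Training time | Parameters |
|---|---|---|---|
| Fully connected (SGD) | 90.79% | ~8 s (5 epochs) | 101K |
| Fully connected (Adam) | 96.24% | ~10 s (5 epochs) | 101K |
| CNN (Adam + Numba) | 98.06% | ~3 min (3 epochs) | 24K |
| CNN with a 44-unit hidden dense layer (Adam + Numba) | 98.30% | ~3 min (3 epochs) | 103K |
The CNN takes longer to train (convolution is far more expensive), but the accuracy gain is clear, and the first CNN row achieves it with fewer parameters (24K vs 101K).
NOTE: Without Numba, at 40 minutes per epoch, training would take 40 × 3 = 120 minutes (2 hours). Numba is what makes training practical.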
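Numba usage tips
1. When to use Numba
Good fit:
- Deeply nested loops
- Numeric computation (arithmetic, math functions)
- numpy array operations
Poor fit:
- Python objects such as dicts and lists
- String manipulation
- File I/O
- Object-oriented code (class methods)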
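2. Common options

    @njit(cache=True, fastmath=True, parallel=False)
    def my_function(x):
        pass

- cache=True: caches the compiled result, so the second run does not recompile
- fastmath=True: allows faster but less precise floating-point math
- parallel=True: parallelizes loops automatically (the iterations must be independent, and the loop has to be written with numba.prange instead of range)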
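3. Debugging
Functions compiled by Numba are hard to debug. When something goes wrong, temporarily remove the @njit decorator and run the function as plain Python:

    # Remove the decorator while debugging
    # @njit(cache=True, fastmath=True)
    def conv_forward_jit(X, W, b, k):
        # ...

Once the logic is confirmed to be correct, put the decorator back.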
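An alternative to editing the source is the `py_func` attribute that Numba attaches to a compiled function: it holds the original, uncompiled Python function, so both versions can be run side by side. A small sketch, where `window_sum` is a made-up example and the `try/except` shim is my assumption so that the snippet also runs where Numba is not installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to a no-op decorator so the sketch runs without Numba
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit
def window_sum(x, k):
    """Sum of every length-k window of a 1-D array."""
    n = x.shape[0] - k + 1
    out = np.zeros(n)
    for i in range(n):
        acc = 0.0
        for j in range(k):
            acc += x[i + j]
        out[i] = acc
    return out


# With Numba, .py_func is the plain Python function; without it, window_sum already is one
python_version = getattr(window_sum, "py_func", window_sum)

x = np.arange(10, dtype=np.float64)
compiled = window_sum(x, 3)
reference = python_version(x, 3)

print(compiled)
print("compiled matches pure Python:", np.allclose(compiled, reference))
```

Because `python_version` is ordinary Python, breakpoints, `print` calls, and full tracebacks all work inside it.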
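4. Type inference
Numba has to infer a type for every variable, so keep each variable's type consistent:

    # Good
    acc = 0.0            # float
    acc += X[i] * W[j]   # float + float

    # Bad
    acc = 0              # int
    acc += X[i] * W[j]   # int + float, may cause a typing error

Saving and loading the model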
Add save and load methods to the model:

    import os

    class MinimalCNNClassifier:
        # ... earlier code ...

        def save(self, path):
            # np.savez does not create missing directories, so do it here
            dir_name = os.path.dirname(path)
            if dir_name:
                os.makedirs(dir_name, exist_ok=True)
            np.savez(
                path,
                conv1_W=self.conv1.W, conv1_b=self.conv1.b,
                conv2_W=self.conv2.W, conv2_b=self.conv2.b,
                fc_W=self.fc.W, fc_b=self.fc.b,
            )
            print(f"Model saved to {path}")

        def load(self, path):
            data = np.load(path)
            self.conv1.W = data["conv1_W"]
            self.conv1.b = data["conv1_b"]
            self.conv2.W = data["conv2_W"]
            self.conv2.b = data["conv2_b"]
            self.fc.W = data["fc_W"]
            self.fc.b = data["fc_b"]
            print(f"Model loaded from {path}")

Usage example:

    # Save
    model.save("models/cnn_model.npz")

    # Load
    model2 = MinimalCNNClassifier(lr=0.001, use_adam=True)
    model2.load("models/cnn_model.npz")

    # Test
    acc = test(model2, X_test, y_test)  # 0.9806

Output:

    Model saved to models/cnn_model.npz!
    Reloading model...
    Model loaded from models/cnn_model.npz
    Reload complete.
    Reloaded model accuracy: 0.9806

Single-sample inference
Add single-sample inference to the model:

    class MinimalCNNClassifier:
        # ... earlier code ...

        def predict_batch(self, X):
            """Predict classes for a batch."""
            y_pred = self.forward(X)
            return np.argmax(y_pred, axis=1)

        def __call__(self, x):
            """
            Single-sample inference.
            Accepts inputs shaped (28,28), (28,28,1) or (1,28,28,1).
            """
            if x.ndim == 2:      # (H, W)
                x = x[None, :, :, None]
            elif x.ndim == 3:    # (H, W, C)
                x = x[None, :, :, :]
            return self.predict_batch(x)[0]

Usage example:

    # Single-sample inference
    sample = X_test[0].squeeze()  # (28, 28)
    pred = model(sample)
    true = np.argmax(y_test[0])
    print(f"Single sample predicted class: {pred}, true: {true}")

Output:

    Single sample predicted class: 7, true: 7

Why the CNN beats the fully connected network
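Let's compare the two by the numbers.
Parameter count
Fully connected network (weights only, biases not counted):
- First layer: 784 × 128 = 100,352
- Second layer: 128 × 10 = 1,280
- Total: 101,632 parameters
CNN (weights only, biases not counted):
- Conv1: 8 × 1 × 3 × 3 = 72
- Conv2: 16 × 8 × 3 × 3 = 1,152
- Dense: 2304 × 10 = 23,040
- Total: 24,264 parameters
The CNN has only about 24% as many parameters as the fully connected network.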
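The totals above count weights only. A short sketch that reproduces them and also shows the effect of including biases; `conv_params` and `dense_params` are illustrative helper names, not functions from the post:

```python
def conv_params(in_c, out_c, k, bias=True):
    """Parameters of a conv layer with out_c kernels of shape (in_c, k, k)."""
    return out_c * in_c * k * k + (out_c if bias else 0)


def dense_params(in_dim, out_dim, bias=True):
    """Parameters of a fully connected layer."""
    return in_dim * out_dim + (out_dim if bias else 0)


def count(bias):
    fc = dense_params(784, 128, bias) + dense_params(128, 10, bias)
    cnn = (conv_params(1, 8, 3, bias)
           + conv_params(8, 16, 3, bias)
           + dense_params(12 * 12 * 16, 10, bias))
    return fc, cnn


fc_w, cnn_w = count(bias=False)
fc_all, cnn_all = count(bias=True)

print(f"weights only: FC {fc_w:,}  CNN {cnn_w:,}  ratio {cnn_w / fc_w:.1%}")
print(f"with biases:  FC {fc_all:,}  CNN {cnn_all:,}  ratio {cnn_all / fc_all:.1%}")
```

Biases barely change the picture. Nearly all of the CNN's parameters sit in the final dense layer (23,040 of 24,264 weights), while the two convolution layers together need only 1,224.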
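Accuracy
- Fully connected (SGD): 90.79%
- Fully connected (Adam): 96.24%
- CNN (Adam): 98.06%
The CNN improves on the fully connected network (Adam) by 1.82 percentage points.
Why the CNN is better
- It uses spatial structure: kernels extract local features and preserve the spatial relationship between pixels.
- Translation tolerance: the same feature is detected wherever it appears in the image.
- Parameter sharing: one kernel slides over the whole image, which cuts the parameter count sharply.
- Hierarchical features: shallow layers pick up edges, deeper layers pick up shapes and textures.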
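The second point can be demonstrated directly: shifting the input shifts the convolution output by the same amount, so the same kernel responds to a feature wherever it sits. The sketch below uses a random 3×3 kernel and a small synthetic blob, both made up for the illustration, and keeps the blob away from the border so that `np.roll` behaves like a plain shift.

```python
import numpy as np


def correlate2d_valid(img, kernel):
    """Stride-1, no-padding 2-D cross-correlation (what the conv layers in this post compute)."""
    k = kernel.shape[0]
    H, W = img.shape
    out = np.zeros((H - k + 1, W - k + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(img[i:i + k, j:j + k] * kernel)
    return out


rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3))

img = np.zeros((10, 10))
img[4:6, 3:5] = 1.0                        # a small blob, well away from the border
shifted = np.roll(img, shift=2, axis=1)    # the same blob, 2 pixels to the right

feat = correlate2d_valid(img, kernel)
feat_shifted = correlate2d_valid(shifted, kernel)

# The feature map of the shifted image is the shifted feature map
same_pattern = np.allclose(feat_shifted, np.roll(feat, shift=2, axis=1))
# A dense layer sees the flattened pixels, which differ between the two images
same_pixels = np.array_equal(img.ravel(), shifted.ravel())

print("feature map moved with the blob:", same_pattern)
print("flattened inputs identical:     ", same_pixels)
```

Strictly speaking this property is equivariance (the response moves with the input); the max pooling layer then discards some of the position information, which is where the tolerance to small shifts comes from.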
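Visualizing the kernels
We can look at what the first convolution layer has learned:

    import matplotlib.pyplot as plt

    # First-layer kernels (8, 1, 3, 3)
    kernels = model.conv1.W.squeeze()  # (8, 3, 3)

    plt.figure(figsize=(12, 2))
    for i in range(8):
        plt.subplot(1, 8, i+1)
        plt.imshow(kernels[i], cmap='gray')
        plt.title(f"Filter {i+1}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

First-layer kernels usually turn into edge detectors: horizontal edges, vertical edges, diagonal edges, and so on.
(I forgot to include the figure here, and have since lost it.)
Complete code
The complete CNN implementation (with Numba acceleration). Note that this listing is the larger variant from the results table: it inserts a 44-unit hidden dense layer (fc1, fc2) between the pooling output and the 10-way output layer.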
    import numpy as np
    import tqdm
    import matplotlib.pyplot as plt
    import os
    import gzip
    from numba import njit

    # ============================================================
    # Load MNIST from Kaggle idx files (local)
    # ============================================================
    def _open_maybe_gz(path):
        if os.path.exists(path):
            return open(path, "rb")
        if os.path.exists(path + ".gz"):
            return gzip.open(path + ".gz", "rb")
        raise FileNotFoundError(f"Cannot find {path} or {path+'.gz'}")

    def load_mnist_from_local(data_dir):
        train_images_path = os.path.join(data_dir, "train-images-idx3-ubyte")
        train_labels_path = os.path.join(data_dir, "train-labels-idx1-ubyte")
        test_images_path = os.path.join(data_dir, "t10k-images-idx3-ubyte")
        test_labels_path = os.path.join(data_dir, "t10k-labels-idx1-ubyte")

        # images: 16-byte header, then uint8 pixels
        with _open_maybe_gz(train_images_path) as f:
            data = np.frombuffer(f.read(), dtype=np.uint8, offset=16)
            X_train = data.reshape(-1, 28, 28, 1) / 255.0

        with _open_maybe_gz(test_images_path) as f:
            data = np.frombuffer(f.read(), dtype=np.uint8, offset=16)
            X_test = data.reshape(-1, 28, 28, 1) / 255.0

        # labels: 8-byte header, then uint8 labels
        with _open_maybe_gz(train_labels_path) as f:
            labels_train = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
        with _open_maybe_gz(test_labels_path) as f:
            labels_test = np.frombuffer(f.read(), dtype=np.uint8, offset=8)

        # one-hot encode the labels
        y_train = np.zeros((labels_train.size, 10))
        y_train[np.arange(labels_train.size), labels_train] = 1

        y_test = np.zeros((labels_test.size, 10))
        y_test[np.arange(labels_test.size), labels_test] = 1

        return X_train, y_train, X_test, y_test

    # ============================================================
    # Basic ops
    # ============================================================
    def relu(x):
        return np.maximum(0, x)

    def relu_grad(x):
        return (x > 0).astype(float)

    def softmax(x):
        x_max = np.max(x, axis=1, keepdims=True)
        ex = np.exp(x - x_max)
        return ex / np.sum(ex, axis=1, keepdims=True)

    def cross_entropy(y_pred, y_true):
        eps = 1e-15
        y_pred = np.clip(y_pred, eps, 1 - eps)
        ce = -np.sum(y_true * np.log(y_pred), axis=1)
        return np.mean(ce)

    def softmax_cross_entropy_grad(y_pred, y_true):
        return (y_pred - y_true) / y_true.shape[0]

    def flatten(x):
        return x.reshape(x.shape[0], -1)

    # ============================================================
    # Numba-accelerated convolution
    # ============================================================
    @njit(cache=True, fastmath=True)
    def conv_forward_jit(X, W, b, k):
        B, H, W_in, C = X.shape
        out_c = W.shape[0]
        out_h = H - k + 1
        out_w = W_in - k + 1
        out = np.zeros((B, out_h, out_w, out_c), dtype=X.dtype)

        for b_idx in range(B):
            for oc in range(out_c):
                for i in range(out_h):
                    for j in range(out_w):
                        acc = 0.0
                        for ic in range(C):
                            for ki in range(k):
                                for kj in range(k):
                                    acc += X[b_idx, i + ki, j + kj, ic] * W[oc, ic, ki, kj]
                        out[b_idx, i, j, oc] = acc + b[oc, 0]
        return out

    @njit(cache=True, fastmath=True)
    def conv_backward_jit(X, grad, W, k):
        B, out_h, out_w, out_c = grad.shape
        in_c = X.shape[3]

        dW = np.zeros_like(W)
        db = np.zeros((out_c, 1), dtype=grad.dtype)
        dX = np.zeros_like(X)

        for b_idx in range(B):
            for oc in range(out_c):
                for ic in range(in_c):
                    for i in range(out_h):
                        for j in range(out_w):
                            g = grad[b_idx, i, j, oc]
                            if g == 0.0:
                                continue
                            for ki in range(k):
                                for kj in range(k):
                                    dW[oc, ic, ki, kj] += g * X[b_idx, i + ki, j + kj, ic]
                                    dX[b_idx, i + ki, j + kj, ic] += g * W[oc, ic, ki, kj]
                # Bias gradient: once per (sample, output channel), outside the ic loop
                db[oc, 0] += np.sum(grad[b_idx, :, :, oc])
        return dW, db, dX

    # ============================================================
    # Numba-accelerated max pooling
    # ============================================================
    @njit(cache=True, fastmath=True)
    def maxpool_forward_jit(X):
        B, H, W, C = X.shape
        out = np.zeros((B, H // 2, W // 2, C), dtype=X.dtype)
        argmax = np.zeros((B, H // 2, W // 2, C), dtype=np.int64)

        for b_idx in range(B):
            for c in range(C):
                # Stop at H - 1 / W - 1 so an odd trailing row or column is
                # skipped instead of being read out of bounds
                for i in range(0, H - 1, 2):
                    for j in range(0, W - 1, 2):
                        max_val = X[b_idx, i, j, c]
                        max_idx = 0
                        idx = 1
                        v = X[b_idx, i, j + 1, c]
                        if v > max_val:
                            max_val = v
                            max_idx = idx
                        idx += 1
                        v = X[b_idx, i + 1, j, c]
                        if v > max_val:
                            max_val = v
                            max_idx = idx
                        idx += 1
                        v = X[b_idx, i + 1, j + 1, c]
                        if v > max_val:
                            max_val = v
                            max_idx = idx
                        out[b_idx, i // 2, j // 2, c] = max_val
                        argmax[b_idx, i // 2, j // 2, c] = max_idx
        return out, argmax

    @njit(cache=True, fastmath=True)
    def maxpool_backward_jit(grad, argmax, H_in, W_in):
        B, H2, W2, C = grad.shape
        dX = np.zeros((B, H_in, W_in, C), dtype=grad.dtype)

        for b_idx in range(B):
            for c in range(C):
                for i in range(H2):
                    for j in range(W2):
                        idx = argmax[b_idx, i, j, c]
                        bi = idx // 2
                        bj = idx - bi * 2
                        dX[b_idx, i * 2 + bi, j * 2 + bj, c] = grad[b_idx, i, j, c]
        return dX

    # ============================================================
    # Convolution Layer
    # ============================================================
    class Conv2D:
        def __init__(self, in_channels, out_channels, kernel_size):
            self.in_c = in_channels
            self.out_c = out_channels
            self.k = kernel_size
            self.W = np.random.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.1
            self.b = np.zeros((out_channels, 1))

        def forward(self, X):
            self.X = X
            self.out = conv_forward_jit(X, self.W, self.b, self.k)
            return self.out

        def backward(self, grad):
            self.dW, self.db, self.dX = conv_backward_jit(self.X, grad, self.W, self.k)
            return self.dX

        def step(self, lr):
            self.W -= lr * self.dW
            self.b -= lr * self.db

    # ============================================================
    # Max Pooling Layer
    # ============================================================
    class MaxPool2x2:
        def forward(self, X):
            self.X = X
            out, self.argmax = maxpool_forward_jit(X)
            return out

        def backward(self, grad):
            B, H, W, C = self.X.shape
            return maxpool_backward_jit(grad, self.argmax, H, W)

    # ============================================================
    # Dense Layer
    # ============================================================
    class Dense:
        def __init__(self, in_dim, out_dim):
            self.W = np.random.randn(in_dim, out_dim) * 0.01
            self.b = np.zeros((1, out_dim))

        def forward(self, X):
            self.X = X
            self.out = X @ self.W + self.b
            return self.out

        def backward(self, grad):
            self.dW = self.X.T @ grad
            self.db = np.sum(grad, axis=0, keepdims=True)
            return grad @ self.W.T

        def step(self, lr):
            self.W -= lr * self.dW
            self.b -= lr * self.db

    # ============================================================
    # CNN Classifier
    # ============================================================
    class MinimalCNNClassifier:
        def __init__(self, lr=0.01, use_adam=False, hidden_dim=44):
            self.lr = lr
            self.use_adam = use_adam
            self.adam = TinyAdam(lr=lr) if use_adam else None
            self.conv1 = Conv2D(1, 8, 3)
            self.conv2 = Conv2D(8, 16, 3)
            self.pool = MaxPool2x2()
            self.hidden_dim = hidden_dim  # choose 44 -> ~103K params total
            self.fc1 = Dense(12 * 12 * 16, hidden_dim)
            self.fc2 = Dense(hidden_dim, 10)

        def forward(self, X):
            out = self.conv1.forward(X)
            out = relu(out)
            self.after_relu1 = out

            out = self.conv2.forward(out)
            out = relu(out)
            self.after_relu2 = out

            out = self.pool.forward(out)
            out = flatten(out)

            out = self.fc1.forward(out)
            self.fc1_pre_relu = out
            out = relu(out)
            self.fc1_post_relu = out

            out = self.fc2.forward(out)
            self.y_pred = softmax(out)
            return self.y_pred

        def backward(self, y_true):
            grad = softmax_cross_entropy_grad(self.y_pred, y_true)

            grad = self.fc2.backward(grad)

            grad = grad * relu_grad(self.fc1_pre_relu)
            grad = self.fc1.backward(grad)

            grad = grad.reshape(-1, 12, 12, 16)
            grad = self.pool.backward(grad)

            grad = grad * relu_grad(self.after_relu2)
            grad = self.conv2.backward(grad)

            grad = grad * relu_grad(self.after_relu1)
            self.conv1.backward(grad)

        def step(self):
            if self.adam:
                self.adam.step([
                    (self.fc2.W, self.fc2.dW),
                    (self.fc2.b, self.fc2.db),
                    (self.fc1.W, self.fc1.dW),
                    (self.fc1.b, self.fc1.db),
                    (self.conv2.W, self.conv2.dW),
                    (self.conv2.b, self.conv2.db),
                    (self.conv1.W, self.conv1.dW),
                    (self.conv1.b, self.conv1.db),
                ])
            else:
                self.fc2.step(self.lr)
                self.fc1.step(self.lr)
                self.conv2.step(self.lr)
                self.conv1.step(self.lr)

        def predict_batch(self, X):
            """Return predicted classes for a batch."""
            y_pred = self.forward(X)
            return np.argmax(y_pred, axis=1)

        def __call__(self, x):
            """
            Single-sample predict wrapper.
            Accepts (28,28), (28,28,1) or (1,28,28,1) style inputs.
            """
            if x.ndim == 2:      # H, W
                x = x[None, :, :, None]
            elif x.ndim == 3:    # H, W, C
                x = x[None, :, :, :]
            return self.predict_batch(x)[0]

        def save(self, path):
            # Only create a directory when the path actually contains one;
            # os.makedirs("") would raise for a bare file name
            dir_name = os.path.dirname(path)
            if dir_name:
                os.makedirs(dir_name, exist_ok=True)
            np.savez(
                path,
                conv1_W=self.conv1.W, conv1_b=self.conv1.b,
                conv2_W=self.conv2.W, conv2_b=self.conv2.b,
                fc1_W=self.fc1.W, fc1_b=self.fc1.b,
                fc2_W=self.fc2.W, fc2_b=self.fc2.b,
            )

        def load(self, path):
            data = np.load(path)
            self.conv1.W = data["conv1_W"]
            self.conv1.b = data["conv1_b"]
            self.conv2.W = data["conv2_W"]
            self.conv2.b = data["conv2_b"]
            self.fc1.W = data["fc1_W"]
            self.fc1.b = data["fc1_b"]
            self.fc2.W = data["fc2_W"]
            self.fc2.b = data["fc2_b"]

    # ============================================================
    # Tiny Adam Optimizer
    # ============================================================
    class TinyAdam:
        """
        A very compact Adam: takes a list of (param, grad) pairs and
        updates the parameters in place.
        """
        def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
            self.lr = lr
            self.beta1 = beta1
            self.beta2 = beta2
            self.eps = eps
            self.t = 0
            self.m = {}
            self.v = {}

        def step(self, params_and_grads):
            self.t += 1
            for param, grad in params_and_grads:
                # State is keyed by array identity, so parameters must be updated in place
                key = id(param)
                m = self.m.get(key, np.zeros_like(param))
                v = self.v.get(key, np.zeros_like(param))

                m = self.beta1 * m + (1 - self.beta1) * grad
                v = self.beta2 * v + (1 - self.beta2) * (grad ** 2)

                m_hat = m / (1 - self.beta1 ** self.t)
                v_hat = v / (1 - self.beta2 ** self.t)

                param -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

                self.m[key] = m
                self.v[key] = v

    # ============================================================
    # Train Helpers
    # ============================================================
    def train(model, X_train, y_train, epochs, batch_size):
        loss_list = []
        for epoch in range(epochs):
            pbar = tqdm.tqdm(
                range(0, len(X_train), batch_size),
                desc=f"Epoch {epoch+1}/{epochs}"
            )

            for i in pbar:
                X_batch = X_train[i:i+batch_size]
                y_batch = y_train[i:i+batch_size]

                y_pred = model.forward(X_batch)
                loss = cross_entropy(y_pred, y_batch)
                loss_list.append(loss)

                model.backward(y_batch)
                model.step()

                pbar.set_postfix({"loss": f"{loss:.4f}"})
        return loss_list

    def test(model, X_test, y_test):
        y_pred_test = model.forward(X_test)
        pred = np.argmax(y_pred_test, axis=1)
        true = np.argmax(y_test, axis=1)
        acc = (pred == true).mean()
        print(f"\nTest accuracy: {acc:.4f}")
        return acc

    def plot_loss(loss_list):
        plt.figure(figsize=(6,4))
        plt.plot(loss_list)
        plt.xlabel("Iteration")
        plt.ylabel("Loss (CE)")
        plt.yscale("log")
        plt.title("Minimal CNN Training Loss")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    # ============================================================
    # Train + Test
    # ============================================================
    if __name__ == "__main__":
        np.random.seed(0)

        # change to your dataset directory
        data_dir = "./mnist"

        # 1. Load MNIST (from your Kaggle files)
        X_train, y_train, X_test, y_test = load_mnist_from_local(data_dir)

        # 2. Create CNN model
        model = MinimalCNNClassifier(lr=0.001, use_adam=True)

        # 3. Train
        loss_list = train(model, X_train, y_train, epochs=3, batch_size=64)  # the CNN converges fast: 3 epochs already gives 90%+

        # 4. Test accuracy
        acc = test(model, X_test, y_test)

        # 5. Plot loss
        plot_loss(loss_list)

        # 6. Save model
        save_path = os.path.join("models", "cnn_model.npz")
        model.save(save_path)
        print(f"Model saved to {save_path}!")

        # 7. Load model (optional)
        print("Reloading model...")
        model2 = MinimalCNNClassifier(lr=0.001, use_adam=True)
        model2.load(save_path)
        print("Reload complete.")

        # 8. Test reloaded model
        acc2 = test(model2, X_test, y_test)
        print(f"Reloaded model accuracy: {acc2:.4f}")

        # 9. Single-sample inference demo
        sample_pred = model2(X_test[0].squeeze())
        print(f"Single sample predicted class: {sample_pred}, true: {np.argmax(y_test[0])}")