(不足之处较多,望包涵。用于自学,帮助理解。)
1、LIFNode在网络架构中的简单应用例子
# 定义并初始化网络
net = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 14 * 14, bias=False), #这里不加bias应该是偏置在SNN中不好表示
neuron.LIFNode(tau=tau),
nn.Linear(14 * 14, 10, bias=False),
neuron.LIFNode(tau=tau)
)
注意,neuron.LIFNode
的作用是将14*14的模拟连续的神经元变为LIF神经元,按0,1发放脉冲,整个网络也就变成了SNN
2、以LIFNode生成Net之后,在训练过程中的简单应用例子
训练部分即按一次batchsize
个样本进行权重的更新,更新之后LIF神经元部分需要重置
for epoch in range(train_epoch):
net.train()
for img, label in tqdm(train_data_loader):
img = img.to(device)
label = label.to(device)
label_one_hot = F.one_hot(label, 10).float()
optimizer.zero_grad()
# 运行T个时长,out_spikes_counter是shape=[batch_size, 10]的tensor
# 记录整个仿真时长内,输出层的10个神经元的脉冲发放次数
for t in range(T):
if t == 0:
out_spikes_counter = net(encoder(img).float())
else:
out_spikes_counter += net(encoder(img).float())
# out_spikes_counter / T 得到输出层10个神经元在仿真时长内的脉冲发放频率
out_spikes_counter_frequency = out_spikes_counter / T
# 损失函数为输出层神经元的脉冲发放频率,与真实类别的MSE
loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
loss.backward()
optimizer.step()
# 优化一次参数后,需要重置网络的状态,因为SNN的神经元是有“记忆”的
functional.reset_net(net)
# 正确率的计算方法如下。认为输出层中脉冲发放频率最大的神经元的下标i是分类结果
accuracy = (out_spikes_counter_frequency.max(1)[1] == label.to(device)).float().mean().item()
writer.add_scalar('train_accuracy', accuracy, train_times)
train_accs.append(accuracy)
train_times += 1
3、一个epoch参数更新后在整个测试集上进行模型测试,方便训练时看效果
net.eval()
with torch.no_grad():
# 每遍历一次全部数据集,就在测试集上测试一次
test_sum = 0
correct_sum = 0
for img, label in test_data_loader:
img = img.to(device)
for t in range(T):
if t == 0:
out_spikes_counter = net(encoder(img).float())
else:
out_spikes_counter += net(encoder(img).float())
correct_sum += (out_spikes_counter.max(1)[1] == label.to(device)).float().sum().item()
test_sum += label.numel()
functional.reset_net(net)
test_accuracy = correct_sum / test_sum
writer.add_scalar('test_accuracy', test_accuracy, epoch)
test_accs.append(test_accuracy)
max_test_accuracy = max(max_test_accuracy, test_accuracy)
print(f'Epoch {epoch}: device={device}, dataset_dir={dataset_dir}, batch_size={batch_size}, learning_rate={learning_rate}, T={T}, log_dir={log_dir}, max_test_accuracy={max_test_accuracy}, train_times={train_times}')
4、注意:T指的是仿真时长,T越大,实验时间越大。所需显存数量与仿真时长 T 线性相关,更长的 T 相当于使用更小的仿真步长,训练更为“精细”,但训练效果不一定更好,因此 T 太大,SNN在时间上展开后就会变成一个非常深的网络,梯度的传递容易衰减或爆炸。由于我们使用了泊松编码器,因此需要较大的 T。
不断更新ing