Bootstrap

【陈工笔记】SNN(Spiking Neural Network)的理解

(不足之处较多,望包涵。用于自学,帮助理解。)

1、LIFNode在网络架构中的简单应用例子

    # 定义并初始化网络
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 14 * 14, bias=False), #这里不加bias应该是偏置在SNN中不好表示
        neuron.LIFNode(tau=tau),
        nn.Linear(14 * 14, 10, bias=False),
        neuron.LIFNode(tau=tau)
    )

注意,neuron.LIFNode的作用是将14*14的模拟连续的神经元变为LIF神经元,按0,1发放脉冲,整个网络也就变成了SNN

2、以LIFNode生成Net之后,在训练过程中的简单应用例子

训练部分即按一次batchsize个样本进行权重的更新,更新之后LIF神经元部分需要重置

    for epoch in range(train_epoch):
        net.train()
        for img, label in tqdm(train_data_loader):
            img = img.to(device)
            label = label.to(device)
            label_one_hot = F.one_hot(label, 10).float()

            optimizer.zero_grad()

            # 运行T个时长,out_spikes_counter是shape=[batch_size, 10]的tensor
            # 记录整个仿真时长内,输出层的10个神经元的脉冲发放次数
            for t in range(T):
                if t == 0:
                    out_spikes_counter = net(encoder(img).float())
                else:
                    out_spikes_counter += net(encoder(img).float())

            # out_spikes_counter / T 得到输出层10个神经元在仿真时长内的脉冲发放频率
            out_spikes_counter_frequency = out_spikes_counter / T

            # 损失函数为输出层神经元的脉冲发放频率,与真实类别的MSE
            loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
            loss.backward()
            optimizer.step()
            # 优化一次参数后,需要重置网络的状态,因为SNN的神经元是有“记忆”的
            functional.reset_net(net)

            # 正确率的计算方法如下。认为输出层中脉冲发放频率最大的神经元的下标i是分类结果
            accuracy = (out_spikes_counter_frequency.max(1)[1] == label.to(device)).float().mean().item()
            
            writer.add_scalar('train_accuracy', accuracy, train_times)
            train_accs.append(accuracy)

            train_times += 1

3、一个epoch参数更新后在整个测试集上进行模型测试,方便训练时看效果

        net.eval()
        with torch.no_grad():
            # 每遍历一次全部数据集,就在测试集上测试一次
            test_sum = 0
            correct_sum = 0
            for img, label in test_data_loader:
                img = img.to(device)
                for t in range(T):
                    if t == 0:
                        out_spikes_counter = net(encoder(img).float())
                    else:
                        out_spikes_counter += net(encoder(img).float())

                correct_sum += (out_spikes_counter.max(1)[1] == label.to(device)).float().sum().item()
                test_sum += label.numel()
                functional.reset_net(net)
            test_accuracy = correct_sum / test_sum
            writer.add_scalar('test_accuracy', test_accuracy, epoch)
            test_accs.append(test_accuracy)
            max_test_accuracy = max(max_test_accuracy, test_accuracy)
        print(f'Epoch {epoch}: device={device}, dataset_dir={dataset_dir}, batch_size={batch_size}, learning_rate={learning_rate}, T={T}, log_dir={log_dir}, max_test_accuracy={max_test_accuracy}, train_times={train_times}')

4、注意:T指的是仿真时长,T越大,实验时间越大。所需显存数量与仿真时长 T 线性相关,更长的 T 相当于使用更小的仿真步长,训练更为“精细”,但训练效果不一定更好,因此 T 太大,SNN在时间上展开后就会变成一个非常深的网络,梯度的传递容易衰减或爆炸。由于我们使用了泊松编码器,因此需要较大的 T。 

不断更新ing

;