SNN快速指北

基于Brain2的SNN模块化代码

#包引用

1
from brian2 import *

#参数设置

1
2
3
tau_default = 10*ms
v_rest = -70*mV
v_thresh = -50*mV

STDP参数

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
# 添加 STDP 的突触模型

stdp_eqs = '''
w : 1
dapre/dt = -apre / tau_pre : 1 (event-driven)
dapost/dt = -apost / tau_post : 1 (event-driven)
'''

on_pre = '''
v_post += w  # 权重控制电压增加
apre += A_pre
w = clip(w + apost, 0, 1)
'''

on_post = '''
apost += A_post
w = clip(w + apre, 0, 1)
'''

#定义神经元模型的微分方程 LIF

1
2
3
eqs = '''
dv/dt = (v_rest - v)/tau : volt     # 单位伏特
'''

H-H

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
eqs_HH = '''
dv/dt = (I - gNa*m**3*h*(v - ENa) - gK*n**4*(v - EK) - gl*(v - El))/Cm : volt
dm/dt = alpham*(1 - m) - betam*m : 1
dn/dt = alphan*(1 - n) - betan*n : 1
dh/dt = alphah*(1 - h) - betah*h : 1

alpham = 0.1*(mV**-1)*(v + 40*mV)/(1 - exp(-(v + 40*mV)/(10*mV)))/ms : Hz
betam  = 4*exp(-(v + 65*mV)/(18*mV))/ms : Hz
alphah = 0.07*exp(-(v + 65*mV)/(20*mV))/ms : Hz
betah  = 1/(1 + exp(-(v + 35*mV)/(10*mV)))/ms : Hz
alphan = 0.01*(mV**-1)*(v + 55*mV)/(1 - exp(-(v + 55*mV)/(10*mV)))/ms : Hz
betan  = 0.125*exp(-(v + 65*mV)/(80*mV))/ms : Hz

I : amp
'''

#构建神经元组 LIF 以输入层为例1

1
2
input_group = NeuronGroup(5, eqs, threshold='v > v_thresh', reset='v = v_rest', method='exact')
input_group.v = v_rest  # 初始电压设置为静息电压

对于neurongroup函数,对应的项分别是: 细胞数, 遵循的模型方程, 阈值, 解析方式(exact为直接解析,适用于有解析解的函数; Euler, 适用于无解析的函数)等

H-H

1
2
3
4
5
G = NeuronGroup(1, eqs_HH, method='exponential_euler',
                threshold='v > -40*mV', reset='v = -65*mV',
                namespace={'Cm': 1*uF, 'gNa': 120*msiemens, 'ENa': 50*mV,
                           'gK': 36*msiemens, 'EK': -77*mV,
                           'gl': 0.3*msiemens, 'El': -54.387*mV})

#创建输入刺激

  1. 模式一: 以神经元组为单位设置初始电压或电流
1
2
3
4
5
G_input.I = 1.2

G_hidden.I = 0.0

G_output.I = 0.0
1
G_inh.v = 1.2  # 初始也会放电一次(你也可以改成 PoissonGroup 随机发放)
  1. 模式二: 精细调控输入层神经元每个神经元的发放时间
1
2
3
4
5
6
7
# 创建输入刺激:让第 0、1、2 个神经元在 10, 20, 30ms 依次放电
input_indices = [0, 1, 2]
input_times = [10, 20, 30]*ms

stimulus = SpikeGeneratorGroup(3, input_indices, input_times)  # `period`: 设置重复周期(比如每 100ms 重复一次这组事件);`dt`: 时间步长控制(可选)
input_connection = Synapses(stimulus, input_group, on_pre='v_post += 2*mV')
input_connection.connect(j='i')  # 每个刺激连接到对应神经元
  1. 模式三: …

#构建突触

  1. 普通突触
1
2
syn_input_hidden = Synapses(input_group, hidden_group, on_pre='v_post += 1.5*mV')
syn_input_hidden.connect(p=0.5)  # 以50%的概率随机连接
  1. STDP
1
2
3
4
5
6
7
S_hidden_output = Synapses(G_hidden, G_output,
                           model=stdp_eqs,
                           on_pre=on_pre,
                           on_post=on_post)

S_hidden_output.connect(p=0.5)  # 随机连接
S_hidden_output.w = 'rand() * 0.2'  # 初始权重较小

#记录运行

1
2
3
4
5
6
7
# 创建记录器
mon_input = StateMonitor(input_group, 'v', record=True)
mon_hidden = StateMonitor(hidden_group, 'v', record=True)
mon_output = StateMonitor(output_group, 'v', record=True)

# 运行仿真
run(100*ms) #运行时间

#可视化

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
figure(figsize=(12, 4))

subplot(1, 3, 1)
plot(mon_input.t/ms, mon_input.v[0]/mV)
title('输入神经元 0')
xlabel('时间 (ms)')
ylabel('电压 (mV)')

subplot(1, 3, 2)
plot(mon_hidden.t/ms, mon_hidden.v[0]/mV)
title('隐藏神经元 0')

subplot(1, 3, 3)
plot(mon_output.t/ms, mon_output.v[0]/mV)
title('输出神经元 0')

tight_layout()
show()

  1. 实际上基于brain2的神经元在构建时的代码别无二致,真正决定功能的是突触的定义 ↩︎

Licensed under CC BY-NC-SA 4.0
comments powered by Disqus
Built with Hugo
Theme Stack designed by Jimmy