-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmulti_neuron_learning.py
executable file
·129 lines (106 loc) · 3.89 KB
/
multi_neuron_learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
####################################################### README ####################################################################
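# Main training script: encodes MNIST training images into spike trains, simulates an output layer of
# leaky integrate-and-fire neurons over T time steps, updates synapse weights with a spike-timing
# learning rule, and finally writes the learned weights (weights.csv) and neuron labels (labels.csv).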
#####################################################################################################################################
import numpy as np
from neuron import neuron
import random
from matplotlib import pyplot as plt
from recep_field import rf
import imageio
from spike_train import encode
from rl import *
from reconstruct import reconst_weights
from parameters import *
import os
import pandas as pd
import time
# Per-neuron traces of membrane potential and threshold, kept only for plotting.
pot_arrays = [[] for _ in range(n)]
pth_arrays = [[] for _ in range(n)]

# Simulation time steps 1..T (inclusive).
time_of_learning = np.arange(1, T + 1, 1)

# Output layer of n neurons, each reset to its initial state.
output_layer = []
for _ in range(n):
    cell = neuron()
    cell.initial()
    output_layer.append(cell)

# Synapse weight matrix (row j holds the input weights of output neuron j).
# Copied so that training does not mutate the shared ``synapse_init`` object
# from the parameters module in place.
synapse = np.copy(synapse_init)

# n x m flags: 1 marks a synapse that has been potentiated at least once.
synapse_memory = np.zeros((n, m))

# Digit label assigned to each output neuron; -1 means "never assigned".
label_neuron = np.repeat(-1, n)
# Root of the MNIST PNG training set: one sub-directory per digit label
# (directory name is parsed with int()), each holding that digit's images.
TRAIN_DIR = "./mnist_png/training/"
# Number of images taken from each digit directory per epoch.
IMAGES_PER_CLASS = 80


def _stdp_update(j, t, train):
    """Apply the spike-timing weight update to every input synapse of neuron ``j``.

    Called when output neuron ``j`` fires at time step ``t``. For each input
    ``h`` the window t, t-1, ..., t + t_back is scanned (t_back comes from the
    parameters module and is presumably <= 0 -- TODO confirm). The most recent
    presynaptic spike found in that window potentiates the synapse with
    ``rl(t1)``. If there is no presynaptic spike in the window, the synapse is
    depressed with ``rl(1)``. The decision is made per firing event, not from
    the synapse's past history.
    """
    for h in range(m):
        for t1 in range(0, t_back - 1, -1):
            if 0 <= t + t1 < T + 1 and train[h][t + t1] == 1:
                synapse[j][h] = update(synapse[j][h], rl(t1))  # strengthen weight
                synapse_memory[j][h] = 1
                break
        else:
            # No presynaptic spike inside the tolerance window: weaken weight.
            synapse[j][h] = update(synapse[j][h], rl(1))


# Training loop: each epoch presents IMAGES_PER_CLASS images of every digit,
# simulates the output layer for every step of time_of_learning, updates the
# firing neuron's synapses, and records which digit each winning neuron saw.
for k in range(epoch):
    print(k)
    # os.listdir order is arbitrary; sort for a reproducible presentation order.
    for folder in sorted(os.listdir(TRAIN_DIR)):
        folder_path = os.path.join(TRAIN_DIR, folder)
        # Skip stray entries (e.g. .DS_Store) that are not digit directories;
        # they would otherwise crash os.listdir / int(folder) below.
        if not (folder.isdigit() and os.path.isdir(folder_path)):
            continue
        for i in sorted(os.listdir(folder_path))[:IMAGES_PER_CLASS]:
            t0 = time.time()
            print(i, " : ")
            img = imageio.imread(os.path.join(folder_path, i))

            # Convolve the image with the receptive field and encode it as a spike train.
            train = np.array(encode(rf(img)))

            # Per-image state.
            winner = False
            count_spikes = np.zeros(n)
            active_pot = np.zeros(n)

            # Leaky integrate-and-fire neuron dynamics.
            for t in time_of_learning:
                for j, x in enumerate(output_layer):
                    # Integrate input only once t is past the neuron's t_rest
                    # (presumably the end of its refractory period).
                    if x.t_rest < t:
                        x.P = x.P + np.dot(synapse[j], train[:, t])
                        if x.P > Prest:
                            # Leak: subtract Pdrop while the potential is above Prest.
                            x.P -= Pdrop
                            # Decay the adaptive threshold by Pthdrop while it is
                            # above the baseline Pth.
                            # NOTE(review): this decay only runs while P > Prest;
                            # confirm that nesting is intended.
                            if x.Pth > Pth:
                                x.Pth -= Pthdrop
                    active_pot[j] = x.P
                    pot_arrays[j].append(x.P)  # Only for plotting: potential over time
                    pth_arrays[j].append(x.Pth)  # Only for plotting: threshold over time

                # Winner-take-all: only the neuron with the highest potential may fire.
                winner = np.argmax(active_pot)
                x = output_layer[winner]
                if active_pot[winner] > x.Pth:
                    x.hyperpolarization(t)
                    x.Pth += 1  # Adaptive membrane / homeostasis: raise the firing threshold
                    count_spikes[winner] += 1
                    _stdp_update(winner, t, train)
                    # Lateral inhibition: every other neuron is inhibited; those
                    # that were above their own threshold still count a spike.
                    for p in range(n):
                        if p != winner:
                            if output_layer[p].P > output_layer[p].Pth:
                                count_spikes[p] += 1
                            output_layer[p].inhibit(t)

            # Bring neuron potentials back to rest before the next image.
            for cell in output_layer:
                cell.initial()

            # NOTE(review): this labels the neuron with the highest potential at
            # the last time step, which can differ from the "Learning Neuron"
            # (most spikes) printed below; confirm which one is intended.
            label_neuron[winner] = int(folder)
            print("Learning Neuron: ", np.argmax(count_spikes))
            print("Learning duration: ", time.time() - t0)
# to write intermediate synapses for neurons
#for p in range(n):
# reconst_weights(synapse[p],str(p)+"_epoch_"+str(k))
# Plotting
# ttt = np.arange(0,len(pot_arrays[0]),1)
# for i in range(n):
# axes = plt.gca()
# plt.plot(ttt,pth_arrays[i], 'r' )
# plt.plot(ttt,pot_arrays[i])
# plt.show()
# Final pass over the output layer: every neuron that never received a digit
# label (label -1) has its first m weights cleared, then each neuron's weight
# vector is handed to reconst_weights under the name "<index>_final".
for idx, lbl in enumerate(label_neuron):
    if lbl == -1:
        for col in range(m):
            synapse[idx][col] = 0
    reconst_weights(synapse[idx], str(idx) + "_final")

# Persist the learned weights and the neuron -> digit labels as CSV files.
np.savetxt("weights.csv", synapse, delimiter=",")
np.savetxt("labels.csv", label_neuron, delimiter=',')