-
Notifications
You must be signed in to change notification settings - Fork 2
/
spike2file.hoc
125 lines (120 loc) · 2.65 KB
/
spike2file.hoc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// Assumes a ParallelContext object named pc has already been created.
// Globally sort spikes by time across the ranks of the ParallelContext pc.
//   $o1  Vector of spike times on this rank; rounded and sorted in place
//   $o2  Vector of gids parallel to $o1; permuted in place together with $o1
//   $o3  output Vector: spike times assigned to this rank, ascending
//   $o4  output Vector: gids parallel to $o3, ascending within equal times
//   $5   number of ranks (starting at rank 0) that receive sorted spikes;
//        clamped to pc.nhost
// Rank r (r < nbin) ends up with the spikes of the r-th of nbin equal-width
// time bins spanning the global [tmin, tmax+1) range.
// If no rank has any spike the proc returns early and $o3/$o4 are untouched.
// spike_count and n are declared local so a call does not create or clobber
// interpreter-global variables of those names.
proc sortspikes() {local res, i, j, k, imin, x, nbin, tmin, tmax, tvl, spike_count, n \
localobj srt, s1, s2, cnts, d1, d2, gs
  s1 = $o1
  s2 = $o2
  d1 = $o3
  d2 = $o4
  nbin = $5 // number of ranks (from 0) to do the sorting
  if (nbin > pc.nhost) { nbin = pc.nhost }
  // for testing, round to the file output resolution
  res = .00001
  s1.div(res).add(.5).floor().mul(res)
  // total number of spikes over all ranks (allreduce type 1 is a sum)
  spike_count = pc.allreduce(s1.size, 1)
  if (pc.id == 0) {printf("spike_count %d, sort using %d ranks\n", spike_count, nbin)}
  // spike_count is identical on every rank, so all ranks return together
  if (spike_count == 0) { return }
  // not always sorted even on per processor basis
  srt = s1.sortindex
  s1.index(s1, srt)
  s2.index(s2, srt)
  if (pc.nhost == 1) {
    // single rank: the local sort is already the global sort; the gid
    // tie-break pass below is still applied so the result has the same
    // ordering as a multi-rank run
    d1.copy(s1)
    d2.copy(s2)
  }else{
    // min and max spike time
    if (s1.size) {
      tmin = s1.x[0]
      tmax = s1.x[s1.size-1]
    }else{
      // no local spikes: use values that cannot win the min/max reductions
      tmin = 1e9
      tmax = 0
    }
    tmin = pc.allreduce(tmin, 3)
    // +1 puts the upper edge of the last bin strictly above every spike
    // time, so the strict "<" test below assigns every spike to some bin
    tmax = pc.allreduce(tmax, 2) + 1
    // exchange
    tvl = (tmax - tmin)/nbin
    // cnts.x[r] is the number of local spikes sent to rank r; entries for
    // ranks >= nbin keep their initial value from new Vector(pc.nhost)
    cnts = new Vector(pc.nhost)
    j = 0
    for i=0, nbin - 1 {
      x = tmin + (i+1)*tvl // upper edge of bin i
      k = 0
      // s1 is sorted, so one forward scan counts the spikes of each bin
      while (j < s1.size) {
        if (s1.x[j] < x) {
          j += 1
          k += 1
        }else{
          break
        }
      }
      cnts.x[i] = k
    }
    pc.alltoall(s1, cnts, d1)
    pc.alltoall(s2, cnts, d2)
    if (d1.size) {
      // received data is a concatenation of per-rank sorted pieces
      srt = d1.sortindex
      d1.index(d1, srt)
      d2.index(d2, srt)
    }
  }
  if (d1.size) {
    // now sort the gids without destroying the spiketime sort
    gs = new Vector()
    imin = 0 // first index of the current run of equal spike times
    n = d1.size
    for i=1, n {
      // i == n is a sentinel iteration that flushes the final run; the
      // nested if keeps d1.x[n] from being evaluated
      if (i < n) if (d1.x[imin] == d1.x[i]) {
        continue
      }
      if (i - imin > 1) {
        // run imin..i-1 shares one spike time: sort just those gids
        gs.resize(0)
        gs.copy(d2, imin, i-1)
        gs.sort
        d2.copy(gs, imin, 0, gs.size-1)
      }
      imin = i
    }
  }
  pc.barrier()
}
// Running total of the startsw() time differences measured inside spike2file()
spike2file_time = 0
// Number of completed spike2file() calls; while it is 0 the next call
// creates/truncates the output file before appending to it
spike2file_ncall = 0
// Sort the given spikes globally and append them as text to a spike file.
//   $o1  Vector of spike times on this rank (emptied on return)
//   $o2  Vector of gids parallel to $o1 (emptied on return)
//   $3   optional number of sorting ranks; used only when exactly three
//        arguments are passed, otherwise 16. Clamped to pc.nhost.
// Ranks 0 .. nsort-1 take turns, separated by barriers, appending their
// "time gid" lines, so the file is written in rank order. For those ranks
// int(pc.id/nsort) is 0, hence the name they use is spk000.dat.
// Updates the globals spike2file_ncall and spike2file_time.
proc spike2file() { local k, turn, nsort, rank, elapsed localobj f, fname, tvec, gvec
  elapsed = startsw()
  if (numarg() == 3) {
    nsort = $3
  }else{
    nsort = 16
  }
  if (nsort > pc.nhost) { nsort = pc.nhost }

  tvec = new Vector()
  gvec = new Vector()
  sortspikes($o1, $o2, tvec, gvec, nsort)

  // total number of spikes held by all ranks after the sort
  k = pc.allreduce(tvec.size, 1)
  if (pc.id == 0) printf("spike2file %d\n", k)

  // nothing beyond nsort will write a file
  rank = pc.id
  fname = new String()
  sprint(fname.s, "spk%03d.dat", int(pc.id/nsort))
  f = new File()
  for turn=0, nsort-1 {
    if (turn == rank) {
      if (turn == 0 && spike2file_ncall == 0) {
        // very first write of the run: start from an empty file
        f.wopen(fname.s)
        f.close()
      }
      f.aopen(fname.s)
      for k=0, tvec.size-1 {
        f.printf("%.5f %d\n", tvec.x[k], gvec.x[k])
      }
      f.close()
    }
    // every rank waits here so only one rank writes per turn
    pc.barrier()
  }

  // the caller's vectors are consumed
  $o1.resize(0)
  $o2.resize(0)

  spike2file_ncall += 1
  elapsed = startsw() - elapsed
  if (pc.id == 0) printf("spike2file call#%d %g\n", spike2file_ncall, elapsed)
  spike2file_time += elapsed
}