forked from abhishekde95/Code
-
Notifications
You must be signed in to change notification settings - Fork 0
/
2PResponseFourierRF.jl
280 lines (250 loc) · 11.2 KB
/
2PResponseFourierRF.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
# Peichao's Notes:
# 1. Code was written for 2P data (Hartley) from Scanbox. Will export results (dataframe and csv) for plotting.
# 2. If you have multiple planes, it works with split & interpolated data. Note results are slightly different.
# 3. If you have single plane, set interpolatedData as false.
using NeuroAnalysis,Statistics,DataFrames,DataFramesMeta,StatsPlots,Mmap,LinearAlgebra,Images,StatsBase,Interact, CSV,MAT, DataStructures, HypothesisTests, StatsFuns, Random, Plots
# Expt info
disk = "O:" # data drive root
subject = "AF4" # Animal
recordSession = "004" # Unit
testId = "004" # Stimulus test
interpolatedData = true # If you have multiple planes. True: use interpolated data; false: use uninterpolated data. Results are slightly different.
delays = -0.066:0.033:0.4 # response delays (sec) relative to stimulus onset: -66 ms to 400 ms in 33 ms steps
ntau = length(collect(delays)) # number of delay taps
print(collect(delays))
isplot = false # toggle diagnostic plotting
## Prepare data & result path
exptId = join(filter(!isempty,[recordSession, testId]),"_") # e.g. "004_004"
dataFolder = joinpath(disk,subject, "2P_data", join(["U",recordSession]), exptId)
metaFolder = joinpath(disk,subject, "2P_data", join(["U",recordSession]), "metaFiles")
## load expt, scanning parameters
# NOTE(review): in this Regex the `*` after `$testId` applies to the LAST character of
# testId (e.g. "004*" matches "00", "004", "0044", ...) — `$(testId).*` may have been
# intended; confirm against actual meta file names.
metaFile=matchfile(Regex("[A-Za-z0-9]*_[A-Za-z0-9]*_$testId*_ot_meta.mat"),dir=metaFolder,join=true)[1]
dataset = prepare(metaFile)
ex = dataset["ex"] # experiment record
envparam = ex["EnvParam"] # stimulus environment parameters
coneType = getparam(envparam,"colorspace")
szhtly_visangle = envparam["x_size"] # stimulus size, deg of visual angle
maxSF = envparam["max_sf"] # maximum spatial frequency, cyc/deg
sbx = dataset["sbx"]["info"] # Scanbox acquisition info
sbxft = ex["frameTimeSer"] # time series of sbx frame in whole recording
# Condition Tests: per-presentation timing and condition indices
envparam = ex["EnvParam"];preicidur = ex["PreICI"];conddur = ex["CondDur"];suficidur = ex["SufICI"]
condon = ex["CondTest"]["CondOn"]     # stimulus onset times
condoff = ex["CondTest"]["CondOff"]   # stimulus offset times
condidx = ex["CondTest"]["CondIndex"] # condition index of each presentation
nstim = size(condidx,1)               # number of stimulus presentations
# condtable = DataFrame(ex["Cond"])
condtable = DataFrame(ex["raw"]["log"]["randlog_T1"]["domains"]["Cond"])
rename!(condtable, [:oridom, :kx, :ky, :bwdom, :colordom])
# Coerce Hartley wave numbers to Int (read back from MAT as floats).
# Fix: `condtable[:kx] = ...` is deprecated/removed DataFrames indexing; `df.col = v` is supported.
condtable.kx = [Int(x) for x in condtable.kx]
condtable.ky = [Int(x) for x in condtable.ky]
# Fix: avoid splatting a large vector into `max`; `maximum` reduces the collection directly.
max_k = maximum(abs.(condtable.kx))
# find out blanks and unique conditions
# NOTE(review): the comment says blanks start FROM 5641, but `>` excludes index 5641 itself —
# confirm whether `.>= 5641` was intended.
blkidx = condidx.>5641 # blanks start from 5641
cidx = .!blkidx
condidx2 = condidx.*cidx + blkidx.* 5641 # collapse all blank trials onto the single index 5641
conduniq = unique(condidx2)
## Load data
if interpolatedData
# multi-plane recordings: use the merged (split & interpolated) segmentation/signal files
segmentFile=matchfile(Regex("[A-Za-z0-9]*[A-Za-z0-9]*_merged.segment"),dir=dataFolder,join=true)[1]
signalFile=matchfile(Regex("[A-Za-z0-9]*[A-Za-z0-9]*_merged.signals"),dir=dataFolder,join=true)[1]
else
# single-plane recording: use the plain (non-merged) files
segmentFile=matchfile(Regex("[A-Za-z0-9]*[A-Za-z0-9]*.segment"),dir=dataFolder,join=true)[1]
signalFile=matchfile(Regex("[A-Za-z0-9]*[A-Za-z0-9].signals"),dir=dataFolder,join=true)[1]
end
segment = prepare(segmentFile)
signal = prepare(signalFile)
# sig = transpose(signal["sig"]) # 1st dimension is cell roi, 2nd is fluorescence trace
spks = transpose(signal["spks"]) # 1st dimension is cell roi, 2nd is spike train
## Plane / segmentation info
planeNum = size(segment["mask"],3) # how many planes
if interpolatedData
planeStart = vcat(1, length.(segment["seg_ot"]["vert"]).+1) # first ROI row of each plane in the stacked spike matrix
end
## Process each plane separately
for pn in 1:planeNum
# pn=2 # for test
# Initialize DataFrame for saving results
recordPlane = string("00",pn-1) # plane/depth; this notation only works for expts with less than 10 planes
siteId = join(filter(!isempty,[recordSession, testId, recordPlane]),"_")
dataExportFolder = joinpath(disk,subject, "2P_analysis", join(["U",recordSession]), siteId, "DataExport")
resultFolder = joinpath(disk,subject, "2P_analysis", join(["U",recordSession]), siteId, "Plots")
isdir(dataExportFolder) || mkpath(dataExportFolder)
isdir(resultFolder) || mkpath(resultFolder)
result = DataFrame()
# ROI vertices: stored per-plane for interpolated data, flat otherwise
if interpolatedData
cellRoi = segment["seg_ot"]["vert"][pn]
else
cellRoi = segment["vert"]
end
cellNum = length(cellRoi)
display("plane: $pn")
display("Cell Number: $cellNum")
# Slice this plane's rows out of the stacked spike matrix
if interpolatedData
# rawF = sig[planeStart[pn]:planeStart[pn]+cellNum-1,:]
spike = spks[planeStart[pn]:planeStart[pn]+cellNum-1,:]
else
# rawF = sig
spike = spks
end
result.py = 0:cellNum-1 # 0-based (python-style) cell id
result.ani = fill(subject, cellNum)
result.dataId = fill(siteId, cellNum)
result.cellId = 1:cellNum # 1-based cell id
## Chop spk trains according to delays
# spk[i,d,n] = mean spike signal of cell n during presentation i, shifted by delays[d]
spk=zeros(nstim,ntau,cellNum)
for d in eachindex(delays)
y,num,wind,idx = epochspiketrain(sbxft,condon.+delays[d], condoff.+delays[d],isminzero=false,ismaxzero=false,shift=0,israte=false)
for i =1:nstim
spkepo = @view spike[:,idx[i][1]:idx[i][end]]
spk[i,d,:]= mean(spkepo, dims=2)
end
end
## Sum cell response of different repeats
# r is indexed [(-kx)+1+max_k, ky+1+max_k, delay, cell]
r = zeros(2*max_k+1,2*max_k+1,ntau,cellNum) # Blank condition [0,0] is now at [max_k+1, max_k+1]
for i=1:nstim
r[-condtable.kx[condidx[i]]+1+max_k,condtable.ky[condidx[i]]+1+max_k,:,:] = r[-condtable.kx[condidx[i]]+1+max_k,condtable.ky[condidx[i]]+1+max_k,:,:]+spk[i,:,:]
end
# Normalize by stim repeats.
reps = zeros(size(conduniq,1))
for i in 1:size(conduniq,1)
rep = length(findall(x->x==conduniq[i],condidx2))
reps[i] = rep
r[-condtable.kx[conduniq[i]]+1+max_k,condtable.ky[conduniq[i]]+1+max_k,:,:] ./= rep
end
## Filter 2D tuning map
for t = 1:ntau
for n = 1:cellNum
rf = r[:,:,t,n]
rf = rf + rot180(rf) # "average over phases", PL — NOTE(review): this is a SUM with the 180°-rotated map (no /2); confirm intended scaling
r[:,:,t,n] = imfilter(rf,Kernel.gaussian((1,1),(3,3))) # 3x3 Gaussian smoothing, sigma = 1 in each dim
end
end
## PL: Build a complex plane with the same size as Hartley space (-kxmax:kxmax, -kymax:kymax) for sf and ori estimation
szhtly = 2*max_k+1
vect = collect(-max_k:max_k)
xx = repeat(vect',szhtly,1) # kx varies along columns
yy = repeat(reverse(vect),1,szhtly) # ky varies along rows, reversed so +ky is at the top
zz= xx + 1im .* yy # complex wavenumber grid: abs -> SF, angle -> orientation
# heatmap(angle.(zz),yflip=true) # -pi to pi
## find best kernel and estimate preferred sf and ori
# per-cell accumulators, push!-ed once per cell in the loop below
taumax=[];kstd=[];kstdmax=[];kernraw=[];kernnor=[];kernest=[];
kdelta=[];signif=[];slambda=[];sfmax=[];orimax=[];sfmean=[];orimean=[];
sfLevel=[];sfResp=[];oriLevel=[];oriResp=[];
for i = 1:cellNum
# i=438
# print(i)
z = r[:,:,:,i] # (kx, ky, delay) response maps for cell i
q = reshape(z,szhtly^2,:) # in this case, there are 61^2 pixels in the stimulus.
# k = dropdims(mapslices(kurtosis,q;dims=1).-3, dims=1) # The kurtosis of any univariate normal distribution is 3. It is common to compare the kurtosis of a distribution to this value.
# kernel "strength" per delay: std over all Hartley pixels at that delay
k = [std(q[:,j]) for j in 1:size(q,2)]
# NOTE(review): equivalent to tmax = argmax(k); kmax = maximum(k) — splatting a vector into max is wasteful
tmax = findall(x->x==max(k...),k)[1]
kmax = max(k...)
# sig = kmax>7
kernRaw = z[:,:,tmax] # raw kernel without blank normalization
kernSub = z[:,:,tmax] .- z[max_k+1,max_k+1,tmax] # kernel normalized by blank (subtractive)
kern = log10.(z[:,:,tmax] ./ z[max_k+1,max_k+1,tmax]) # kernel normalized by blank (log ratio)
replace!(kern, -Inf=>0) # log10 of a zero response gives -Inf; zero it out
# separability measure and estimate kernel
u,s,v = svd(kernRaw)
s = Diagonal(s)
lambda = s[1,1]/s[2,2] # ratio of top two singular values: space-space separability
q = s[1,1]
s = zeros(size(s))
s[1,1] = q
kest = u*s*v' # estimated kernel (rank-1 reconstruction)
# energy measure
delta = kmax / k[1] - 1 # peak std relative to the first (earliest, negative-delay) tap
sig = delta > 0.25 # responsive if peak exceeds baseline tap by >25%
# find the maxi/best condition
# bw = kern .> (max(kern...) .* 0.95)
bwmax = kernSub.== max(kernSub...)
idxmax = findall(x->x==1,bwmax)
foreach(x->if x[1]>(max_k+1) bwmax[x]=0 end,idxmax) # choose upper quadrants
# estimate ori/sf by max
# NOTE(review): divides by length(idxmax)/2, presumably because maxima come in symmetric pairs of which the lower-quadrant one was zeroed — confirm
zzm = sum(sum(zz.*bwmax,dims=1),dims=2)[1] / (length(idxmax)/2)
sf_max = abs(zzm)/szhtly_visangle # cyc/deg
ori_max = rad2deg(angle(zzm)) # deg
# find the best condition based on thresholding
bw = kernSub .>= quantile(kernSub[:], 0.99) # top 1% of blank-subtracted pixels
# bw = kernSub .>= (max(kernSub...) .* 0.95)
idx = findall(x->x==1,bw)
foreach(x->if x[1]>(max_k+1) bw[x]=0 end,idx) # choose upper quadrants
idx = findall(x->x==1,bw) # choose upper quadrants
# estimate ori/sf by mean
zzm = sum(sum(zz.*bw,dims=1),dims=2)[1] / length(idx) # centroid of thresholded pixels in the complex plane
sf_mean = abs(zzm)/szhtly_visangle # cyc/deg
ori_mean = rad2deg(angle(zzm)) # deg
## Ori tuning curve
# sf_best = max((abs.(zz).*bwmax)...)
# idxsf = findall(x->x==sf_best,abs.(zz))
# filter!(x->x[1]<=(max_k+1),idxsf)
#
# ori_idx=rad2deg.(angle.(reverse(zz[idxsf])))
# ori_curve=reverse(kern[idxsf])
# sample orientation within the SF band spanned by the thresholded best pixels
sf_best = extrema(abs.(zz)[idx])
idxsf = findall(x->sf_best[1] <= x <= sf_best[2],abs.(zz))
filter!(x->x[1]<=(max_k+1),idxsf) # choose upper quadrants
filter!(x-> !((x[1]==(max_k+1)) & (x[2]<max_k+1)),idxsf) # remove 180 deg
oriCurve = DataFrame()
oriCurve.level=rad2deg.(angle.(zz[idxsf])) # orientation (deg)
oriCurve.resp=kernSub[idxsf] # blank-subtracted response
sort!(oriCurve)
filter!(:resp => resp -> !any(f -> f(resp), (ismissing, isnothing, isnan, isinf)), oriCurve)
# averaged over repeated orientation
gp=groupby(oriCurve, :level)
ori_level=[];ori_resp=[];
for g in gp
push!(ori_level,mean(g.level))
push!(ori_resp,mean(g.resp))
end
# plot(ori_level,ori_resp)
## SF tuning curve
# sample SF within the orientation band spanned by the thresholded best pixels
ori_best = extrema(angle.(zz)[idx])
idxori = findall(x->ori_best[1] <= x <= ori_best[2],angle.(zz))
# filter!(x->x[1]<=(max_k+1),idxori) # choose upper quadrants
sfCurve = DataFrame()
sfCurve.level = (abs.(zz[idxori]))./szhtly_visangle # spatial frequency (cyc/deg)
sfCurve.resp = kernSub[idxori] # blank-subtracted response
sort!(sfCurve)
# Keep only spatial frequencies within the stimulus' maximum SF.
# Fix: `sfCurve[:level]` is deprecated/removed DataFrames indexing; use `sfCurve.level`.
sfCurve = sfCurve[sfCurve.level .<= maxSF, :]
# drop missing/NaN/Inf responses before averaging
filter!(:resp => resp -> !any(f -> f(resp), (ismissing, isnothing, isnan, isinf)), sfCurve)
# averaged over repeated sf
gp=groupby(sfCurve, :level)
sf_level=[];sf_resp=[];
for g in gp
push!(sf_level,mean(g.level))
push!(sf_resp,mean(g.resp))
end
# plot(sf_level,sf_resp)
# accumulate this cell's results (one entry per cell, in cell order)
push!(taumax,tmax);push!(kstd,k);push!(kstdmax, kmax); push!(kernraw,kernRaw);push!(kernnor,kern);
push!(kernest,kest);push!(signif,sig);push!(kdelta,delta); push!(slambda,lambda);
push!(orimax,ori_max); push!(sfmax,sf_max);push!(orimean,ori_mean); push!(sfmean,sf_mean);
push!(oriLevel, ori_level);push!(oriResp,ori_resp); push!(sfLevel,sf_level); push!(sfResp,sf_resp)
# if sig == true
# heatmap(kmask,yflip=true, aspect_ratio=:equal,color=:coolwarm)
# plot([0 real(zzm)],[0 imag(zzm)],'wo-','linewidth',3,'markersize',14);
# end
end
# Scalar summary columns (these go into the CSV via result1 below)
result.signif = signif
result.taumax = taumax
result.kstdmax = kstdmax
result.kdelta = kdelta
result.slambda = slambda
result.orimax = orimax
result.sfmax = sfmax
result.orimean = orimean
result.sfmean = sfmean
result1=copy(result) # snapshot with scalar columns only — CSV-friendly
# Array-valued columns (kernels, tuning curves) are saved only to the JLD2 file
result.kstd = kstd
result.kernnor = kernnor
result.kernraw = kernraw
result.kernest = kernest
result.oriLevel = oriLevel
result.oriResp = oriResp
result.sfLevel = sfLevel
result.sfResp = sfResp
#Save results
CSV.write(joinpath(resultFolder,join([subject,"_",siteId,"_",coneType,"_tuning_result.csv"])), result1)
# NOTE(review): `save` is not in the explicit `using` list — presumably reexported (FileIO/JLD2) by NeuroAnalysis; confirm
save(joinpath(dataExportFolder,join([subject,"_",siteId,"_",coneType,"_tuning_result.jld2"])), "result",result)
end