# NOTE(review): this file arrived collapsed onto one line (comments and
# docstrings included) and with text missing in two places, marked
# NOTE(review) inside ACF_s below. The line layout here is reconstructed
# from the token stream; the code cannot run until the missing text is
# restored from the original file.
# NOTE(review): names such as exp, cumsum, zeros, convolve, append, arange
# come from these star imports. `find` is not a numpy function and is
# presumably pylab's matplotlib.mlab.find, which recent matplotlib releases
# no longer provide -- verify against the installed version.
from numpy import *
from pylab import *


def ACF_s(h,Ainf,xi,dt,order = 1,epsilon = 1e-2,Tol=1e-2):
    """
    Find the steady-state autocorrelation function of spike trains
    according to quasi-renewal theory.

    INPUT
    h : scaled input
    Ainf : activity value used in the intensity (passed to lam_s); it is
        re-estimated with Ah during the order-2 refinement
    xi : exp(eta)-1, defines refractoriness/adaptation
    dt : timestep of the linear grid defining xi (the original author
        notes: "I think of seconds")
    order : 0 -> renewal theory
            1 -> adaptation through the term integral xi(s) A(t-s) ds
            2 -> additional correlation through the term
                 double integral xi(s) xi(s') C(s-s') ds ds',
                 where C is the ACF
    epsilon : level of precision for xi (truncation threshold)
    Tol : for the second-order ACF, tolerance on the MSE between two
        successive estimates of the ACF (from zero to the cutoff defined
        by epsilon); same units as Ainf**2

    OUTPUT
    Cplus : autocorrelation function on a linear grid with step dt
        (one-sided, without the Dirac delta)
    """
    # steady-state intensity, survivor function exp(-integral of lam), and
    # the density lam*surv, all on the grid of xi
    lam = lam_s(h,Ainf,xi,dt,order)
    surv = exp(-cumsum(lam)*dt)
    rho = lam*surv
    L = len(lam)
    #initial values
    # NOTE(review): text is missing inside the next line. It looks as if
    # everything from a '<' (presumably `<epsilon ...`) to a later '>'
    # (presumably `... > Tol`) was stripped. The lost text must have
    # defined Cplus (used just below) and opened the order-2 refinement
    # loop whose body follows; the indentation of that body is inferred.
    # Not reconstructed -- restore from the original source.
    cutoff = find(abs(surv)Tol:
        Cest = array(Cplus)
        # loop to calculate correction
        correction = zeros(L)
        # previous ACF estimate with its last (longest-lag) value subtracted
        g2 = Cest-Cest[-1]
        # xiC[k] = dt * sum_j xi[j]*g2[k-j], i.e. the 'full' convolution
        # truncated to the first L samples
        xiC = dt*convolve(xi,g2,'full')[:L]
        # running sum: correction[i] = -2*dt*sum_{j<i} xi[j]*(xiC[-1]-xiC[j])
        for i in range(1,L):
            correction[i] = correction[i-1] - dt*(2*xi[i-1]*(xiC[-1]-xiC[i-1]))
        correction = .5*(correction)# factor one half comes from the factorial in the TME
        # re-estimate the activity and the intensity including the
        # order-2 term, then the corresponding survivor function/density
        Ainf = Ah(h,xi,dt,Tol = 1e-3,order=2,g2s = correction)[1]
        lamest = lam_s(h,Ainf,xi,dt,order=2,g2s = correction)
        survest = exp(-cumsum(lamest)*dt)
        rhoest = lamest*survest
        # NOTE(review): more text is missing inside the next line: the rest
        # of the refinement loop, ACF_s's return statement, and the `def`
        # line plus opening statements of the function that continues
        # below. Judging by the call sites `Ah(h,xi,dt,Tol = 1e-3,order=2,
        # g2s = correction)[1]` and `Ah(hbefore,sum(xik,1),dt,Tol = 1e-3,
        # order=1,g2s = 0)`, the code after it is the tail of
        # Ah(h, xi, dt, Tol, order, g2s); `lamr` and `Ar` are defined in the
        # lost part, and the trailing `=1:` is presumably the end of
        # `if order >=1:`. Indentation of that tail is inferred.
        cutoff = find(abs(survest)=1:
        # xis[k] = dt * (sum of xi over indices > k): the same tail sum that
        # lam_s uses, scaled by Ainf, for its order-1 term
        xis = cumsum(xi)*dt
        xis = xis[-1]-xis
        if order ==2:
            # order 2 adds the correction term, matching lam_s's Ainf*g2s
            xis=xis+g2s
        factor = sum(exp( -cumsum(lamr)*dt ) *cumsum(lamr*xis)*dt)*dt
        # The commented-out line suggests the starting guess is a root of
        # factor*A**2 + A/Ar - 1 = 0.
        # NOTE(review): the closed form below does not match that quadratic,
        # whose roots are (-1 +/- sqrt(1+4*factor*Ar**2))/(2*factor*Ar);
        # the code takes sqrt(1-4*factor/Ar**2) instead. It only seeds the
        # iteration, but a negative square-root argument yields nan -- verify.
        #Aest=roots([factor,1/Ar,-1])[1]
        Aest = (1+sqrt(1-4*factor/Ar**2))/(-2*Ar*factor)
        # self-consistency residual: 1/Aest minus dt*sum of the survivor
        # function obtained with intensity lamr*exp(Aest*xis)
        mismatch=(1/Aest -sum(exp(-cumsum(lamr*exp(Aest*xis)*dt)))*dt)
        # damped update: the new 1/Aest is 1/3 of the old 1/Aest plus 2/3 of
        # dt*sum(survivor function).
        # NOTE(review): no iteration cap; the loop never ends if the residual
        # fails to drop below Tol.
        while abs(mismatch) > Tol:
            Aest = 1/(1/Aest-mismatch/1.5)
            mismatch=(1/Aest -sum(exp(-cumsum(lamr*exp(Aest*xis)*dt)))*dt)
    # returns [Ar] for order 0 and [Ar, Aest] for order >= 1
    if order ==0:
        return [Ar]
    if order >=1:
        return [Ar,Aest]


def lam_s(h,Ainf,xi,dt,order = 1,g2s = 0):
    """
    Find the firing intensity at steady state.

    Returns lam = exp(exponent)*(1+xi) on the grid of xi, where
      order 0: exponent = h
      order 1: exponent = h + Ainf*dt*(sum(xi)-cumsum(xi))
      order 2: exponent = (order-1 exponent) + Ainf*g2s
    At each grid index, dt*(sum(xi)-cumsum(xi)) is dt times the sum of xi
    over the later indices.

    NOTE(review): any other value of `order` leaves `lam` unassigned, so
    the return raises UnboundLocalError.
    """
    if order==0:
        lam = exp(h)*(1+xi)
    if order==1:
        lam = exp(h+dt*(sum(xi)-cumsum(xi))*Ainf)*(1+xi)
    if order ==2:
        lam = exp(h+dt*(sum(xi)-cumsum(xi))*Ainf + Ainf*g2s)*(1+xi)
    return lam


def c_s(rho,cutoff,dt):
    """
    Find the autocorrelation function by forward-time iteration.

    rho : 1-d array sampled with step dt (presumably the interval density,
        cf. rho = lam*surv in ACF_s)
    cutoff : number of leading samples of rho used as the kernel
    dt : grid step

    The kernel y = rho[:cutoff]*dt is normalized to sum to one, then
        Cplus[0] = rho[0]
        Cplus[i] = rho[i] + sum_{k=0}^{min(i,cutoff)-1} Cplus[i-1-k]*y[k]
    Returns Cplus, one value per sample of rho.

    NOTE(review): if a non-empty rho[:cutoff] sums to zero, the
    normalization turns y (and hence Cplus) into nan.
    """
    y = rho[:cutoff]*dt
    y = y/sum(y)
    L = len(rho)
    Cplus = zeros(L)
    # window of the most recent Cplus values, newest first, zero-padded to
    # length cutoff so it lines up element-wise with y
    Cplusvec = zeros(cutoff)
    # forward-time iteration for acf
    for i in range(L):
        Cplus[i] = rho[i]+sum(Cplusvec*y)
        Cplusvec = append(array(Cplus[i-arange(min(i,cutoff-1)+1)]),zeros(max(cutoff-i-1,0)))
    return Cplus


def EncodeFast(aparam,tauparam,h,dt,epsilon=1e-2, hbefore = 0):
    """
    Encode the effective input h into population activity, using the
    convenient parametrization eta = ln(1-e(-t/tau)) with a mixture of
    exponentials.

    INPUTS
    aparam : 1-d array, amplitude of each exponential; must sum to one
        (enforced anyway: it is divided by its sum)
    tauparam : 1-d array, time constant of each exponential
    h : 1-d array, filtered input
    dt : time step, in seconds
    epsilon : truncation level; the kernels are sampled up to
        tc = -log(epsilon)*max(tauparam), where the slowest exponential
        has decayed to epsilon
    hbefore : assumes a constant input hbefore before the input h

    OUTPUTS
    A : activity (same grid as h)

    R. Naud 2012.08.
    """
    # window length: the slowest exponential decays to epsilon at tc
    tc = -log(epsilon)*max(tauparam)
    indtc = int(tc/dt)+1
    # NOTE(review): len(arange(0,tc,dt)) is ceil(tc/dt), which differs from
    # indtc when tc/dt is a whole number; the column assignments below then
    # fail with a shape mismatch. arange(indtc)*dt would always match.
    t=arange(0,tc,dt)
    # last sample moved far out, so exp(-t/tau) is ~0 there for every tau
    t[-1]=100*max(tauparam)
    aparam = aparam/sum(aparam)
    K = len(aparam)
    L = len(h)
    # xik[:,i] = aparam[i]*exp(-t/tauparam[i]);
    # xis = sum_i tauparam[i]*xik[:,i];  XI[:,i] = tauparam[i]*xik[:,i]
    xik = zeros((indtc,K))
    xis = zeros(indtc)
    XI = zeros((indtc,K))
    for i in range(K):
        xik[:,i] = aparam[i]*exp(-t/tauparam[i])
        xis = xis + xik[:,i]*tauparam[i]
        XI[:,i] = tauparam[i]*xik[:,i]
    # initialize as if constant h = hbefore for time before t
    # NOTE(review): Ah returns [Ar, Aest] for order>=1, so temp[0] is Ar
    # (the order-0 value), not the order-1 estimate -- confirm intent.
    # NOTE(review): Ah receives xi = +sum(xik,1), but the hazard below uses
    # the factor (1-sum(xik,1)) with exponent -Ainf*xis, and the docstring's
    # eta = ln(1-e(-t/tau)) gives xi = exp(eta)-1 = -sum(xik,1) when the
    # amplitudes sum to one; check the sign.
    temp = Ah(hbefore,sum(xik,1),dt,Tol = 1e-3,order=1,g2s = 0)
    Ainf = temp[0]
    # m[j]: population fraction whose last spike was j steps ago (as the
    # shift in the loop implies). Start from the survivor function under
    # the constant-input hazard, normalized to sum to one (the Ainf*dt
    # factor cancels in the normalization).
    m = exp(-cumsum(exp(hbefore - Ainf*xis)*(1-sum(xik,1)))*dt)*Ainf*dt
    m = m/sum(m)
    A = zeros(L)+Ainf
    # past activity, newest first (Avec[j] = activity j steps ago)
    Avec = zeros(indtc)+Ainf
    # NOTE(review): XI starts at +tauparam[k]*xik[:,k], but the fixed point
    # of the update below for constant A is -tauparam[k]*A*xik[:,k] (the
    # value the initial m assumes through -Ainf*xis); confirm the start.
    # NOTE(review): sum(xik,1) is loop-invariant and could be hoisted.
    for i in range(1,L):
        # Euler step moving XI[:,k] toward -tauparam[k]*Avec*xik[:,k]
        for k in range(K):
            XI[:,k] = XI[:,k]-dt*(XI[:,k]/tauparam[k] + Avec*xik[:,k])
        # age every bin by one step, applying one-step survival under the
        # current hazard; the oldest bin m[-1] is dropped
        m[1:] = m[0:-1]*exp(-dt*(1-sum(xik,1)[:-1])*exp(h[i]+sum(XI,1)[:-1]))
        # the previous step's activity mass A[i-1]*dt enters at age zero
        m[0] = A[i-1]*dt
        # activity = mass missing from m; NOTE(review): this also counts the
        # dropped oldest bin m[-1] as if it had fired
        A[i] = (1-sum(m))/dt
        Avec= concatenate((array([A[i]]),Avec[0:-1]))
    return A