function [MDstar, Dstar, Mastar] = MDcode(Set1, Set2, Window, Type)
% [MDstar, Dstar, Mastar] = MDcode(Set1, Set2, Window, Type)
%
% INPUTS
% Set1 and Set2 are matrices containing the sets of spike trains to be
% compared. They are built by stacking (adding rows of) binned spike trains
% made of zeros and ones. The number of rows corresponds to the number of
% repetitions and must be at least 2 in each set, because the bias
% correction averages over pairs of distinct repetitions. The number of
% columns corresponds to the number of (small) time bins, and must be the
% same for both sets.
% Window is the coincidence window used to compute the inner products.
% It must have odd length; it is centred on its middle element, and it
% should be symmetric around that element if the inner product Type is 2.
% Type = 1 corresponds to the van Rossum metric (Eq. 10 of Naud et al.
% (2011)) with an arbitrary coincidence window. Type = 2 corresponds to the
% coincidence factor (Eq. 14 of Naud et al. (2011)) with an arbitrary
% coincidence window. (Any value other than 1 is treated as Type 2.)
%
% OUTPUTS
% MDstar is the bias-corrected normalized distance (Eq. 44 of Naud et al.
% (2011)).
% Dstar is the bias-corrected distance (Eq. 42 of Naud et al. (2011)).
% Mastar is the bias-corrected angular separation (Eq. 43 of Naud et al.
% (2011)).
%
% ALGORITHM
% Every spike train is convolved with Window, and the result is centred on
% the window's middle element (conv2 with the 'same' option). Bins near the
% edges only receive contributions from spikes inside the recording; nothing
% wraps around from one end of the train to the other.
% Inner products are averaged over all pairs of distinct repetitions.
%
% - R. Naud, November 2010

if mod(numel(Window), 2) ~= 1, error('Window should have odd length'), end

Nrep1 = size(Set1, 1);
Nrep2 = size(Set2, 1);
% The unbiased pair averages divide by Nrep*(Nrep-1); with fewer than two
% repetitions they would silently become NaN.
if Nrep1 < 2 || Nrep2 < 2
    error('MDcode:tooFewRepetitions', ...
        'Each set needs at least 2 repetitions (rows) for the bias-corrected estimate.');
end
if size(Set1, 2) ~= size(Set2, 2)
    error('MDcode:sizeMismatch', ...
        'Set1 and Set2 must have the same number of time bins (columns).');
end

% Work in double precision so logical/integer spike matrices are accepted.
Set1 = double(Set1);
Set2 = double(Set2);
kernel = double(Window(:).');   % row kernel: conv2 filters along each row

% Centred convolution of every repetition with the window. For an odd-length
% kernel, 'same' keeps columns Ncoinc+1 .. Ncoinc+Nbins of the full
% convolution, so bins near the edges are truncated, never wrapped around.
Set1f = conv2(Set1, kernel, 'same');
Set2f = conv2(Set2, kernel, 'same');

if Type == 1 % van Rossum like: filtered train against filtered train
    nu1 = mean(Set1f, 1); nu2 = mean(Set2f, 1);
    m1 = meanPairProduct(Set1f, Set1f);
    m2 = meanPairProduct(Set2f, Set2f);
else % Type 2 coincidence factor-like: raw train against filtered train
    nu1 = mean(Set1f, 1); nu2 = mean(Set2, 1);
    m1 = meanPairProduct(Set1, Set1f);
    m2 = meanPairProduct(Set2, Set2f);
end

MDstar = 2*sum(nu1.*nu2)/(m1 + m2);
if nargout > 1
    Dstar = (m1 + m2)*(1 - MDstar);
end
if nargout > 2
    Mastar = MDstar*(m1 + m2)/2/sqrt(m1*m2);
end

function m = meanPairProduct(A, B)
% Mean, over all pairs of rows i < j, of the inner product sum(A(i,:).*B(j,:)).
% A and B must have the same number of rows (at least 2) and columns.
% A running sum of B(i+1:end,:) avoids re-summing the trailing rows for
% every i.
n = size(A, 1);
later = zeros(1, size(B, 2));   % sum of B(i+1:end,:) for the current i
total = 0;
for i = n:-1:1
    total = total + sum(A(i,:).*later);
    later = later + B(i,:);
end
m = total*2/(n*(n - 1));