Skip to content
This repository has been archived by the owner. It is now read-only.
Permalink
master
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
classdef gridSorter < handle
properties (Constant)
gridSorterVersion=1.00; %version number of the gridSorter class
end
properties (SetAccess=public)
%stage overwrite flags: when nonzero, findSortingFiles treats the files of that stage as missing so the stage is recomputed
overwriteSpikeExtraction=0; %spike detection files (per channel)
overwriteFeatureExtraction=0; %feature extraction files (per channel)
overwriteClustering=0; %clustering files (per channel)
overwriteMerging=0; %merged waveform file (AllMergedWaveforms.mat)
overwriteFitting=0; %fitting file (spikeSorting.mat)
overwritePostProcessignAnalysis=0; %post-processing file (postProcessignAnalysis.mat); the spelling of the name is part of the interface
sortingDir %output folder, set by the constructor to <recordingDir>/<recordingName>_spikeSort
runWithoutSaving2File=false; %if true, spikeDetectionNSK keeps its results in memory instead of writing them to disk
fastPrinting=0; %if true figures are saved as jpeg with imwrite(getframe) instead of print at 300 dpi
saveFigures=1; %save the diagnostic figures to sortingDir
selectedChannelSubset=[]; %real channel numbers to sort; if empty, all channels of the recording (channelNumbers) are used
dataRecordingObj %the data recording object
sortingFileNames %struct with the file names and existence flags of all stages (filled by findSortingFiles)
upSamplingFrequencySpike=60000; %[Hz] spike waveforms are spline-upsampled to this frequency; rounded to an integer multiple of the recording sampling frequency
localGridSize=3; %side of the square local grid used for feature extraction; must be an odd number (if empty, calculateChParameters skips the local grid computation)
localGridExt=1; %number of grid steps added on each side of the local grid; channels within this distance of the central channel are checked for merging
detectionChunkOverlap=1; %ms - subtracted from each chunk end time in spikeDetectionNSK (end = next chunk start - detectionChunkOverlap)
%for using a 16 bit signed integer to represent the spike shapes
detectionMaxSpikeAmp=1000; %amplitude represented by the largest int16 value (sets detectionInt2uV)
detectionNQuantizationBits=16; %number of quantization bits (sets detectionInt2uV)
detectionGaussianityWindow=20;%ms - window length for the kurtosis-based noise estimation
detectionPreSpikeWindow=2;%ms - waveform length before the spike
detectionPostSpikeWindow=3;%ms - waveform length after the spike
detectionSpikeTimeShiftInterval=1;%ms - the maximal interval after the threshold crossing on which the spike minimum is searched
detectionPeakDetectionSmoothingWindow=0.3; %ms - width of the gaussian smoothing kernel used to find the exact spike minimum
detectionKurtosisNoiseThreshold=3; %windows with kurtosis below this value are used to estimate the noise
detectionSpikeDetectionThresholdStd=5; %detection threshold = noise mean - detectionSpikeDetectionThresholdStd * noise std
detectionMaxChunkSize=2*60*1000; %[ms]
detectionMinimumDetectionIntervalMs=0.5;%ms - half width of the window around the spike minimum used to compare channels
detectionRemoveSpikesNotExtremalOnLocalGrid=true; %removes all spikes that have a stronger minimum on another channel surrounding the detection channel on the grid
%filter parameters, copied to the filterData object in getHighpassFilter
filterHighPassPassCutoff=250;
filterHighPassStopCutoff=200;
filterLowPassPassCutoff=2500;
filterLowPassStopCutoff=3500;
filterAttenuationInHighpass=6;
filterAttenuationInLowpass=6;
filterRippleInPassband=0.4;
filterDesign='ellip';
featuresNWaveletCoeff=30; %number of wavelet coefficients kept after the KS-based selection
featuresDimensionReductionPCA=6; %number of PCA components kept
featuresReduceDimensionsWithPCA=true; %reduce the selected wavelet coefficients with PCA
featuresSelectedWavelet='haar'; %the mother wavelet
featuresWTdecompositionLevel=4; %the level of decomposition
featuresFeatureExtractionMethod='wavelet'; %the method for extracting features: 'wavelet' or 'PCA'
featuresConcatenateElectrodes=true; %concatenate all surrounding electrodes to 1 big voltage trace
clusteringMaxSpikesToCluster=10000; %maximum number of spikes to feed to feature extraction and clustering (not to template matching)
clusteringMethod='meanShift'; %'meanShift' or 'kMeans' (the methods handled in spikeClusteringNSK)
clusteringMergingMethod='projectionMeanStd'; %'projectionMeanStd' or 'MSEdistance'
clusteringMaxIter=1000; %max iterations for the k-means algorithm
clusteringNReplicates=50; %number of k-means replicates
clusteringInitialClusterCentersMethod='cluster'; %method for initial conditions in k-means algorithm ('start' option)
clusteringMaxPointsInSilhoutte=1000; %the maximal number of data points per cluster used for silhouette quality estimation
clusteringMergeThreshold=0.18; %threshold for merging clusters %the higher the threshold the less cluster will merge
clusteringSTDMergeFac=2; %the fraction of a gaussian to check from the crossing point on both sides (the higher the more clusters separate into groups)
clusteringRunSecondMerging=0; %run projection merging a second time on the merged clusters
clusteringMaxClusters=12; %the maximum number of clusters for a specific channel (k-means)
clusteringMinimumChannelRate=0.01; %[Hz] channels with a lower spike rate are not clustered
clusteringMinNSpikesCluster=10;
clusteringMinSpikesTotal=50; %do not attempt to cluster channels with less spikes
clusteringPlotProjection=1;
clusteringPlotClassification=1; %plot the classification and feature space figures for every clustered channel
mergingThreshold=0.1;
mergingRecalculateTemplates=true;
mergingNStdSpikeDetection=4;
mergingAllignSpikeBeforeAveraging=true;
mergingAllignWaveShapes=true;
mergingTestInitialTemplateMerging=false;
mergingPlotStatistics=true;
mergingNStdNoiseDetection=6; %template samples within +/- this many standard errors of zero are considered non-significant
mergingPreSpike4NoiseDetection=0.2; %fraction of the pre-spike samples used for noise detection
mergingPostSpike4NoiseDetection=0.2; %fraction of the post-spike samples used for noise detection
mergingNoiseThreshold=0.2; %templates with a lower fraction of significant samples are marked as noise
mergingMaxSpikeShift=1; %ms
%fitting and post-processing parameters (presumably used by spikeFittingNSK / spikePostProcessingNSK)
fittingTemplateMethod=1;
fittingMaxLag=1.6;
fittingLagIntervalSamples=3;
postMaxSpikes2Present=1000;
postPlotAllAvgSpikeTemplates=true;
postExtractFilteredWaveformsFromSpikeTimes=true;
postExtractRawLongWaveformsFromSpikeTimes=true;
postFilteredSNRStartEnd=[-1 1];
postRawSNRStartEnd=[5 40];
postPlotFilteredWaveforms=true;
postPlotRawLongWaveforms=true;
postPlotSpikeReliability=true;
postPreFilteredWindow=2;
postTotalFilteredWindow=5;
postPreRawWindow=10;
postTotalRawWindow=90;
end
properties (SetAccess=protected)
nCh %number of selected channels (numel(selectedChannelSubset)), set by calculateChParameters
chPar %struct with channel index maps and local-grid neighbourhoods, set by calculateChParameters
arrayExt %half width of the local grid, (localGridSize-1)/2
detectionInt2uV %scale from the stored int16 spike shapes to amplitude units: detectionMaxSpikeAmp/2^(detectionNQuantizationBits-1)
filterObj %filterData object used for filtering the data, created by getHighpassFilter
end
properties (SetObservable, AbortSet = true, SetAccess=public)
%currently empty
end
properties (Hidden, SetAccess=protected)
%currently empty
end
methods
%class constructor
function obj=gridSorter(dataRecordingObj,varargin)
    %GRIDSORTER Create a grid spike sorter for a data recording object.
    %   obj=gridSorter(dataRecordingObj) uses the default parameters.
    %   obj=gridSorter(dataRecordingObj,'propertyName',propertyValue,...) first
    %   sets the given properties and then:
    %     - stores dataRecordingObj (a non-empty positional argument takes
    %       precedence over a 'dataRecordingObj' option),
    %     - sets sortingDir to <recordingDir>/<recordingName>_spikeSort unless a
    %       sortingDir option was given, and creates that folder if needed,
    %     - computes the channel parameters (calculateChParameters).
    %   Errors if no data recording object is available.
    if mod(numel(varargin),2)~=0
        error('gridSorter:invalidOptions','Options must be given as ''propertyName'',propertyValue pairs');
    end
    %collect all options given as a 'propertyName',propertyValue series
    for i=1:2:numel(varargin)
        obj.(varargin{i})=varargin{i+1};
    end
    %nargin>=1 (not ==1) so that passing options does not drop the recording object
    if nargin>=1 && ~isempty(dataRecordingObj)
        obj.dataRecordingObj=dataRecordingObj;
    end
    if isempty(obj.dataRecordingObj)
        error('gridSorter:noRecording','Data recording object not entered. In later versions a GUI will be implemented for recording selection');
    end
    [~, name, ~] = fileparts(obj.dataRecordingObj.recordingName{1});
    if isempty(obj.sortingDir) %respect a sortingDir given as an option
        obj.sortingDir=[obj.dataRecordingObj.recordingDir filesep name '_spikeSort'];
    end
    %make directory in recording folder with spike sorting data
    if ~exist(obj.sortingDir,'dir')
        mkdir(obj.sortingDir);
        disp(['Creating spike sorting folder: ' obj.sortingDir]);
    end
    obj=obj.calculateChParameters;
end
function [obj,t,ic,avgWaveform]=runSorting(obj)
    %RUNSORTING Run all spike sorting stages, skipping stages whose results already exist.
    %   Stages, in order: spike detection, feature extraction, clustering,
    %   merging of duplicate neurons, fitting and post-processing. A stage is
    %   skipped when findSortingFiles reports its output file(s) as existing
    %   (the overwrite* properties force a stage to be recomputed). The elapsed
    %   time of every stage is printed (tic/toc), and the sorter properties are
    %   saved to <sortingDir>/metaData.mat after spike detection.
    %   Outputs t, ic and avgWaveform are currently always returned empty.
    t=[];ic=[];avgWaveform=[];
    obj=obj.findSortingFiles;

    %%%%%%%%%%%%%%%%%%%%%%%%%%%% spike Detection %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    tic;
    if all(obj.sortingFileNames.spikeDetectionExist) %check for the existence of spike shapes
        disp('Sorting will be preformed on previously detected waveforms');
    else
        obj=obj.spikeDetectionNSK;
    end
    toc;
    %save meta-data
    props=getProperties(obj);
    save([obj.sortingDir filesep 'metaData.mat'],'props');

    %%%%%%%%%%%%%%%%%%%%%%%%%%%% feature extraction %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    tic;
    if all(obj.sortingFileNames.featureExtractionExist) %check for the existence of features
        disp('Sorting will be preformed on previously extracted features');
    else
        obj=obj.spikeFeatureExtractionNSK;
    end
    toc;

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% clustering %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    tic;
    if all(obj.sortingFileNames.clusteringExist) %check for the existence of clustering results
        disp('Sorting will be preformed on previously extracted clusters');
    else
        %spikeClusteringNSK returns the clustering of the last channel (idx,...),
        %not the sorter object. gridSorter is a handle class, so the object is
        %updated in place and the outputs must not be assigned to obj.
        obj.spikeClusteringNSK;
    end
    toc;

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Merging duplicate neurons %%%%%%%%%%%%%%%%%%%%%
    tic;
    if obj.sortingFileNames.mergedAvgWaveformExist %check for the existence of merged templates
        disp('Sorting will be preformed on previously merged clusters');
    else
        obj=obj.spikeMergingClustersNSK;
    end
    toc;

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Fitting %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    tic;
    if obj.sortingFileNames.fittingExist %check for the existence of fitting results
        disp('Fitting skipped - using previously saved fitting results');
    else
        obj=obj.spikeFittingNSK;
    end
    toc;

    %%%%%%%%%%%%%%% General final plots and assess sorting quality %%%%%%%%%%%%%%%%%%%%%
    tic;
    if all(obj.sortingFileNames.postProcessignAnalysisExist) %check for the existence of post-processing results
        disp('Post analysis skipped - using previously saved post analysis results');
    else
        obj=obj.spikePostProcessingNSK;
    end
    toc;
end
function obj=calculateChParameters(obj)
    %CALCULATECHPARAMETERS Build the channel index maps and the local-grid neighbourhoods.
    %   Sets obj.nCh, obj.arrayExt and the following fields of obj.chPar:
    %     s2r   - serial channel index -> real channel number (= selectedChannelSubset)
    %     r2s   - real channel number -> serial index (0 for channels that are not selected)
    %     En    - channel layout of the recording (real channel numbers, NaN = no channel)
    %     rEn   - the layout in serial indices (NaN where no selected channel is present)
    %     surChExt{i}, pValidSurChExt{i}, surChExtVec{i}, nValidChExt(i)
    %           - the extended (localGridSize+2*localGridExt) neighbourhood of
    %             serial channel i, its valid positions and its channels as a row vector
    %     pCenterCh(i)     - position of channel i in surChExtVec{i}
    %     pSurCh{i}        - positions in surChExtVec{i} of the localGridSize neighbourhood
    %     pSurChOverlap{i} - positions in surChExtVec{i} of the channels within
    %                        localGridExt of channel i (channel i itself excluded)
    %     sharedChNames{i}{j}, pSharedCh1{i}{j}, pSharedCh2{i}{j}
    %           - intersect of surChExtVec{i} and surChExtVec{j} for every
    %             overlapping neighbour j of channel i
    %   The neighbourhood fields are only computed when localGridSize is not empty.

    %determine channels
    if isempty(obj.selectedChannelSubset)
        obj.selectedChannelSubset=obj.dataRecordingObj.channelNumbers; %take all channels in the recording
    end
    obj.nCh=numel(obj.selectedChannelSubset);
    obj.chPar.s2r=obj.selectedChannelSubset; %transformation between the serial channel number and the real ch number
    %get channel layout
    obj.chPar.En=obj.dataRecordingObj.chLayoutNumbers;
    layoutCh=obj.chPar.En(~isnan(obj.chPar.En));
    %transformation between the real channel number and the serial ch number.
    %r2s must cover every channel of the layout (not only the selected ones),
    %otherwise indexing it with unselected layout channels is out of bounds
    obj.chPar.r2s=zeros(1,max([obj.selectedChannelSubset(:);layoutCh(:);0]));
    obj.chPar.r2s(obj.selectedChannelSubset)=1:obj.nCh;
    %limit layout only to the serial channels selected by user
    obj.chPar.rEn=obj.chPar.En;
    obj.chPar.rEn(~isnan(obj.chPar.En))=obj.chPar.r2s(layoutCh);
    obj.chPar.rEn(obj.chPar.rEn==0)=NaN;
    [nRowsTmp,nColsTmp]=size(obj.chPar.En);
    localGridSizeExt=obj.localGridSize+obj.localGridExt*2;
    %initiate arrays (also clears results of a previous call)
    obj.chPar.surChExt=cell(1,obj.nCh);
    obj.chPar.pValidSurChExt=cell(1,obj.nCh);
    obj.chPar.surChExtVec=cell(1,obj.nCh);
    obj.chPar.pCenterCh=zeros(1,obj.nCh);
    obj.chPar.nValidChExt=zeros(1,obj.nCh);
    obj.chPar.pSurCh=cell(1,obj.nCh);
    obj.chPar.pSurChOverlap=cell(1,obj.nCh);
    obj.chPar.sharedChNames=cell(1,obj.nCh);
    obj.chPar.pSharedCh1=cell(1,obj.nCh);
    obj.chPar.pSharedCh2=cell(1,obj.nCh);
    if ~isempty(localGridSizeExt)
        obj.arrayExt=(obj.localGridSize-1)/2;
        overheadGridExtension=(localGridSizeExt-1)/2;
        %pad the layout with NaN so that neighbourhoods of border channels stay inside the array
        EnExt=NaN(nRowsTmp+overheadGridExtension*2,nColsTmp+overheadGridExtension*2);
        EnExt(1+overheadGridExtension:end-overheadGridExtension,1+overheadGridExtension:end-overheadGridExtension)=obj.chPar.rEn;
        for i=1:obj.nCh
            [x,y]=find(EnExt==i);
            if numel(x)~=1
                error('gridSorter:channelLayout','Channel %d must appear exactly once in the channel layout (chLayoutNumbers)',obj.chPar.s2r(i));
            end
            %find the surrounding channels on which feature extraction will be performed
            surCh=EnExt(x-obj.arrayExt:x+obj.arrayExt,y-obj.arrayExt:y+obj.arrayExt);
            pValidSurCh=find(~isnan(surCh)); %do not remove find
            %find the channels that are overhead step from the central channel - these are the channels whose waveforms should be checked for merging
            surChOverlap=EnExt(x-obj.localGridExt:x+obj.localGridExt,y-obj.localGridExt:y+obj.localGridExt);
            surChOverlap(obj.localGridExt+1,obj.localGridExt+1)=NaN; %exclude the central channel
            pValidSurChOverlap=find(~isnan(surChOverlap)); %do not remove find
            %find the extended channels for merging of the same neurons detected on nearby channels
            obj.chPar.surChExt{i}=EnExt(x-overheadGridExtension:x+overheadGridExtension,y-overheadGridExtension:y+overheadGridExtension);
            obj.chPar.pValidSurChExt{i}=find(~isnan(obj.chPar.surChExt{i})); %do not remove find
            obj.chPar.nValidChExt(i)=numel(obj.chPar.pValidSurChExt{i});
            obj.chPar.surChExtVec{i}=obj.chPar.surChExt{i}(obj.chPar.pValidSurChExt{i}(:))';
            obj.chPar.pCenterCh(i)=find(obj.chPar.surChExtVec{i}==i); %the position of the central channel in surChExtVec
            [~,obj.chPar.pSurCh{i}]=intersect(obj.chPar.surChExtVec{i},surCh(pValidSurCh(:)));
            [~,obj.chPar.pSurChOverlap{i}]=intersect(obj.chPar.surChExtVec{i},surChOverlap(pValidSurChOverlap(:)));
        end
    end
    %map channel intersections
    for i=1:obj.nCh
        %(i,i) is not included: the centre of surChOverlap was set to NaN above
        for j=obj.chPar.surChExtVec{i}(obj.chPar.pSurChOverlap{i})
            [obj.chPar.sharedChNames{i}{j},obj.chPar.pSharedCh1{i}{j},obj.chPar.pSharedCh2{i}{j}]=intersect(obj.chPar.surChExtVec{i},obj.chPar.surChExtVec{j});
        end
    end
end
function obj=findSortingFiles(obj)
    %FINDSORTINGFILES Build the file names of all sorting stages and check which already exist.
    %   Per serial channel i (file names use the real channel number chPar.s2r(i)):
    %     spikeDetectionFile{i}    / spikeDetectionExist(i)    - ch_<n>_spikeDetection.mat
    %     featureExtractionFile{i} / featureExtractionExist(i) - ch_<n>_featureExtraction.mat
    %     clusteringFile{i}        / clusteringExist(i)        - ch_<n>_clustering.mat
    %   Files common to all channels (in sortingDir):
    %     avgWaveformFile            - AllClusteredWaveforms.mat
    %     mergedAvgWaveformFile      - AllMergedWaveforms.mat (mergedAvgWaveformExist)
    %     fittingFile                - spikeSorting.mat (fittingExist)
    %     postProcessignAnalysisFile - postProcessignAnalysis.mat; when it exists, its
    %                                  postProcessignAnalysisExist variable replaces the flag
    %   Existence flags hold the value of exist(...,'file') (0 = missing). Only for
    %   the 3 per-channel stages can a subset of uncalculated channels be computed.
    %   The overwrite* properties force the corresponding flags to 0.

    %reset the per-channel fields so that entries of a previous call with more channels do not persist
    obj.sortingFileNames.spikeDetectionFile=cell(1,obj.nCh);
    obj.sortingFileNames.spikeDetectionExist=zeros(1,obj.nCh);
    obj.sortingFileNames.featureExtractionFile=cell(1,obj.nCh);
    obj.sortingFileNames.featureExtractionExist=zeros(1,obj.nCh);
    obj.sortingFileNames.clusteringFile=cell(1,obj.nCh);
    obj.sortingFileNames.clusteringExist=zeros(1,obj.nCh);
    for i=1:obj.nCh
        chFileBase=[obj.sortingDir filesep 'ch_' num2str(obj.chPar.s2r(i))];
        obj.sortingFileNames.spikeDetectionFile{i}=[chFileBase '_spikeDetection.mat'];
        obj.sortingFileNames.spikeDetectionExist(i)=exist(obj.sortingFileNames.spikeDetectionFile{i},'file');
        obj.sortingFileNames.featureExtractionFile{i}=[chFileBase '_featureExtraction.mat'];
        obj.sortingFileNames.featureExtractionExist(i)=exist(obj.sortingFileNames.featureExtractionFile{i},'file');
        obj.sortingFileNames.clusteringFile{i}=[chFileBase '_clustering.mat'];
        obj.sortingFileNames.clusteringExist(i)=exist(obj.sortingFileNames.clusteringFile{i},'file');
    end
    obj.sortingFileNames.avgWaveformFile=[obj.sortingDir filesep 'AllClusteredWaveforms.mat'];
    obj.sortingFileNames.mergedAvgWaveformFile=[obj.sortingDir filesep 'AllMergedWaveforms.mat'];
    obj.sortingFileNames.mergedAvgWaveformExist=exist(obj.sortingFileNames.mergedAvgWaveformFile,'file');
    obj.sortingFileNames.fittingFile=[obj.sortingDir filesep 'spikeSorting.mat'];
    obj.sortingFileNames.fittingExist=exist(obj.sortingFileNames.fittingFile,'file');
    obj.sortingFileNames.postProcessignAnalysisFile=[obj.sortingDir filesep 'postProcessignAnalysis.mat'];
    obj.sortingFileNames.postProcessignAnalysisExist=exist(obj.sortingFileNames.postProcessignAnalysisFile,'file');
    if obj.sortingFileNames.postProcessignAnalysisExist
        tmp=load(obj.sortingFileNames.postProcessignAnalysisFile,'postProcessignAnalysisExist');
        if isfield(tmp,'postProcessignAnalysisExist')
            obj.sortingFileNames.postProcessignAnalysisExist=tmp.postProcessignAnalysisExist;
        else %the file does not contain the flag - treat post processing as not done
            obj.sortingFileNames.postProcessignAnalysisExist=0;
        end
    end
    %if overwriting is needed, treat the files of that stage as non existing
    if obj.overwriteSpikeExtraction
        obj.sortingFileNames.spikeDetectionExist(:)=0;
    end
    if obj.overwriteFeatureExtraction
        obj.sortingFileNames.featureExtractionExist(:)=0;
    end
    if obj.overwriteClustering
        obj.sortingFileNames.clusteringExist(:)=0;
    end
    if obj.overwriteMerging
        obj.sortingFileNames.mergedAvgWaveformExist=0;
    end
    if obj.overwriteFitting
        obj.sortingFileNames.fittingExist=0;
    end
    if obj.overwritePostProcessignAnalysis
        obj.sortingFileNames.postProcessignAnalysisExist=0;
    end
end
function obj=getHighpassFilter(obj)
    %GETHIGHPASSFILTER Create obj.filterObj (a filterData object) from the filter* properties and design it.
    %   Despite the name, the filter is designed with designBandPass using
    %   both the high-pass and the low-pass cutoff settings.
    obj.filterObj=filterData(obj.dataRecordingObj.samplingFrequency);
    obj.filterObj.highPassPassCutoff=obj.filterHighPassPassCutoff;
    obj.filterObj.highPassStopCutoff=obj.filterHighPassStopCutoff;
    obj.filterObj.lowPassPassCutoff=obj.filterLowPassPassCutoff;
    obj.filterObj.lowPassStopCutoff=obj.filterLowPassStopCutoff;
    obj.filterObj.attenuationInHighpass=obj.filterAttenuationInHighpass;
    obj.filterObj.attenuationInLowpass=obj.filterAttenuationInLowpass;
    obj.filterObj.filterDesign=obj.filterDesign;
    %pass the passband ripple to the filter object (assigning it to obj itself would have no effect).
    %NOTE(review): the property name rippleInPassband follows the naming of the other
    %filterData properties - confirm; it is only set when filterData defines it
    if isprop(obj.filterObj,'rippleInPassband')
        obj.filterObj.rippleInPassband=obj.filterRippleInPassband;
    end
    obj.filterObj=obj.filterObj.designBandPass;
end
function obj=spikeDetectionNSK(obj)
    %SPIKEDETECTIONNSK Detect spikes and extract upsampled, aligned spike waveforms.
    %   Processes every serial channel i with spikeDetectionExist(i)==0. The
    %   band-pass filtered recording is read in chunks of at most
    %   detectionMaxSpikeSize ms; in each chunk and for each channel:
    %     1. the noise mean/std are estimated on the central channel from
    %        windows whose kurtosis is below detectionKurtosisNoiseThreshold
    %     2. downward crossings of Th = noiseMean - detectionSpikeDetectionThresholdStd*noiseStd
    %        are detected on the central channel
    %     3. windows around each crossing (on all channels of the extended local
    %        grid) are spline-upsampled to upSamplingFrequencySpike and realigned
    %        to the minimum of the smoothed central channel
    %     4. if detectionRemoveSpikesNotExtremalOnLocalGrid, spikes whose minimum
    %        is stronger on another channel of the grid are discarded
    %   Waveforms are stored as int16 (divided by detectionInt2uV). Results are
    %   written to sortingFileNames.spikeDetectionFile{i} (Th, spikeTimes [ms],
    %   spikeShapes and the window parameters), or kept in
    %   obj.spikeDetectionData when runWithoutSaving2File is true.
    %   NOTE(review): consecutive chunks are separated (not overlapped) by
    %   detectionChunkOverlap ms, since each end time is the next start time minus it.
    %Todo:
    %Add detection of multiple spikes

    %determine quantization
    obj.detectionInt2uV=obj.detectionMaxSpikeAmp/2^(obj.detectionNQuantizationBits-1);
    obj=obj.getHighpassFilter;
    %start spike detection
    fprintf('\nRunning spike detection on %s...',obj.dataRecordingObj.dataFileNames{1});
    upSamplingFactor=obj.upSamplingFrequencySpike/obj.dataRecordingObj.samplingFrequency;
    if upSamplingFactor~=round(upSamplingFactor) %check that upsampling factor is an integer
        upSamplingFactor=round(upSamplingFactor);
        obj.upSamplingFrequencySpike=upSamplingFactor*obj.dataRecordingObj.samplingFrequency;
        disp(['upSampling factor was not an integer and was rounded to: ' num2str(upSamplingFactor)]);
    end
    %window sizes in samples of the recording
    gaussianityWindow=obj.detectionGaussianityWindow/1000*obj.dataRecordingObj.samplingFrequency;
    testSamples=gaussianityWindow*1000; %number of samples from the chunk start used for noise estimation
    preSpikeSamples=obj.detectionPreSpikeWindow/1000*obj.dataRecordingObj.samplingFrequency; %must be > spikePeakInterval
    postSpikeSamples=obj.detectionPostSpikeWindow/1000*obj.dataRecordingObj.samplingFrequency; %must be > spikePeakInterval
    spikeTimeShiftIntervalSamples=obj.detectionSpikeTimeShiftInterval/1000*obj.dataRecordingObj.samplingFrequency;
    postSpikeSamplesInitial=postSpikeSamples+spikeTimeShiftIntervalSamples; %extra samples needed to realign the window to the spike minimum
    timeVec=-preSpikeSamples:postSpikeSamplesInitial;
    intrpTimeVec=timeVec(1):(1/upSamplingFactor):timeVec(end);
    pZeroTimeVec=find(intrpTimeVec>=0,1,'first'); %position of the threshold crossing in the upsampled window
    %window sizes in upsampled samples
    preSpikeSamplesIntrp=preSpikeSamples*upSamplingFactor; %must be > spikePeakInterval
    postSpikeSamplesIntrp=postSpikeSamples*upSamplingFactor; %must be > spikePeakInterval
    spikeTimeShiftIntervalIntrp=spikeTimeShiftIntervalSamples*upSamplingFactor; %must be > spikePeakInterval
    minimumDetectionIntervalSamplesIntrp=obj.detectionMinimumDetectionIntervalMs/1000*obj.dataRecordingObj.samplingFrequency*upSamplingFactor;
    peakDetectionSmoothingSamples=round(obj.detectionPeakDetectionSmoothingWindow/1000*obj.dataRecordingObj.samplingFrequency*upSamplingFactor);
    peakSmoothingKernel=fspecial('gaussian', [3*peakDetectionSmoothingSamples 1] ,peakDetectionSmoothingSamples);
    %determine the chunk size
    recordingDuration=obj.dataRecordingObj.recordingDuration_ms;
    if obj.detectionMaxChunkSize>recordingDuration
        startTimes=0;
        endTimes=recordingDuration;
    else
        startTimes=0:obj.detectionMaxChunkSize:recordingDuration;
        %drop a chunk that would start exactly at the end of the recording (zero length)
        startTimes=startTimes(startTimes<recordingDuration);
        endTimes=[startTimes(2:end)-obj.detectionChunkOverlap recordingDuration];
    end
    nChunks=numel(startTimes);
    obj.nCh=numel(obj.chPar.s2r);
    channels2Detect=find(obj.sortingFileNames.spikeDetectionExist==0); %channels without a spike detection file
    channels2Detect=channels2Detect(:)'; %row vector so that for-loops iterate over channels
    matFileObj=cell(1,obj.nCh);
    spikeTimesAll=cell(obj.nCh,nChunks);
    if obj.runWithoutSaving2File %if saving data is not required
        spikeShapesAll=cell(obj.nCh,nChunks);
    else
        for i=channels2Detect
            matFileObj{i} = matfile(obj.sortingFileNames.spikeDetectionFile{i},'Writable',true);
            matFileObj{i}.spikeShapes=zeros(preSpikeSamplesIntrp+postSpikeSamplesIntrp,0,obj.chPar.nValidChExt(i),'int16');
        end
    end
    %initiate arrays
    Th=zeros(obj.nCh,nChunks);nCumSpikes=zeros(1,obj.nCh);
    fprintf('\nExtracting spikes from chunks (total %d): ',nChunks);
    for j=1:nChunks
        fprintf('%d ',j);
        %get filtered data of all selected channels ([samples x channels])
        MAll=squeeze(obj.filterObj.getFilteredData(obj.dataRecordingObj.getData(obj.chPar.s2r(1:obj.nCh),startTimes(j),endTimes(j)-startTimes(j))))';
        nSamples=size(MAll,1);
        for i=channels2Detect %go over all channels that require rewriting
            %get local data (columns ordered as in surChExtVec{i})
            Mlong=MAll(:,obj.chPar.surChExtVec{i});
            %estimate channel noise from low kurtosis windows of the central channel
            tmpData=buffer(Mlong(1:min(testSamples,nSamples),obj.chPar.pCenterCh(i)),gaussianityWindow,gaussianityWindow/2);
            noiseSamples=tmpData(:,kurtosis(tmpData,0)<obj.detectionKurtosisNoiseThreshold);
            noiseStd=std(noiseSamples(:));
            noiseMean=mean(noiseSamples(:));
            Th(i,j)=noiseMean-obj.detectionSpikeDetectionThresholdStd*noiseStd;
            %find downward threshold crossings whose window fits inside the chunk
            thresholdCrossings=find(Mlong(1:end-1,obj.chPar.pCenterCh(i))>Th(i,j) & Mlong(2:end,obj.chPar.pCenterCh(i))<Th(i,j));
            thresholdCrossings=thresholdCrossings(thresholdCrossings>preSpikeSamples & thresholdCrossings<nSamples-postSpikeSamplesInitial);
            %extract upsample and align spikes
            startSamplesInM=[];startSamplesInIdx=[];
            if ~isempty(thresholdCrossings)
                %linear indices of the spike windows: [window samples x spikes x local channels]
                startSamplesInM(1,1,:)=(0:nSamples:(nSamples*(obj.chPar.nValidChExt(i)-1)));
                idx=bsxfun(@plus,bsxfun(@plus,thresholdCrossings',(-preSpikeSamples:postSpikeSamplesInitial)'), startSamplesInM );
                M=Mlong(idx);
                %upsample data
                M = interp1(timeVec, M, intrpTimeVec, 'spline');
                nSamplesShort=size(M,1);
                %align spike windows to the minimum of the smoothed central channel
                Msmooth = convn(M((pZeroTimeVec+1):(pZeroTimeVec+spikeTimeShiftIntervalIntrp),:,obj.chPar.pCenterCh(i)), peakSmoothingKernel, 'same');
                [~,shift]=min(Msmooth);
                spikeTimesTmp=startTimes(j)+(thresholdCrossings'+shift/upSamplingFactor)/obj.dataRecordingObj.samplingFrequency*1000; %[ms]
                nSamplesPerCh=numel(spikeTimesTmp)*nSamplesShort;
                startSamplesInIdx(1,1,:)=(0:nSamplesPerCh:nSamplesPerCh*obj.chPar.nValidChExt(i)-1);
                idx=bsxfun(@plus , bsxfun(@plus,pZeroTimeVec+shift+(0:nSamplesShort:(nSamplesPerCh-1)),(-preSpikeSamplesIntrp:(postSpikeSamplesIntrp-1))') , startSamplesInIdx);
                M=M(idx);
                if obj.detectionRemoveSpikesNotExtremalOnLocalGrid
                    %check for a minimum (negative spike peak) over all channels to detect the channel with the strongest amplitude for each spike
                    [~,maxP]=min( min( M((preSpikeSamplesIntrp-minimumDetectionIntervalSamplesIntrp):(preSpikeSamplesIntrp+minimumDetectionIntervalSamplesIntrp),:,:) ,[],1) ,[],3);
                    p=(maxP==obj.chPar.pCenterCh(i));
                    M=M(:,p,:);
                    spikeTimesTmp=spikeTimesTmp(p'); %the p' is important to create a 1-0 empty matrix (not 0-1) which cell2mat can handle
                end
                spikeTimesAll{i,j}=spikeTimesTmp;
                if ~obj.runWithoutSaving2File
                    tmpSpikeCount=numel(spikeTimesTmp);
                    if tmpSpikeCount>0
                        nCumSpikes(i)=nCumSpikes(i)+tmpSpikeCount;
                        matFileObj{i}.spikeShapes(:,(nCumSpikes(i)-tmpSpikeCount+1):nCumSpikes(i),:)=int16(M./obj.detectionInt2uV);
                    end
                else
                    spikeShapesAll{i,j}=int16(M./obj.detectionInt2uV);
                end
            else
                spikeTimesAll{i,j}=[];
                if obj.runWithoutSaving2File
                    spikeShapesAll{i,j}=[];
                end
            end
        end
    end
    clear MAll Mlong M;
    if ~obj.runWithoutSaving2File %write files to disk
        for i=channels2Detect
            matFileObj{i}.Th=Th(i,:);
            matFileObj{i}.spikeTimes=cell2mat(spikeTimesAll(i,:));
            matFileObj{i}.preSpikeSamplesIntrp=preSpikeSamplesIntrp;
            matFileObj{i}.postSpikeSamplesIntrp=postSpikeSamplesIntrp;
            matFileObj{i}.upSamplingFrequencySpike=obj.upSamplingFrequencySpike;
            matFileObj{i}.minimumDetectionIntervalSamplesIntrp=minimumDetectionIntervalSamplesIntrp;
            matFileObj{i}.detectionInt2uV=obj.detectionInt2uV;
        end
    else %keep results in memory
        %NOTE(review): spikeDetectionData is not declared in the properties blocks
        %visible in this file - confirm it exists, otherwise these assignments fail
        obj.spikeDetectionData.spikeShapes=cell(1,obj.nCh);
        obj.spikeDetectionData.spikeTimes=cell(1,obj.nCh);
        for i=channels2Detect %the channels that were detected in this call
            obj.spikeDetectionData.spikeShapes{i}=cell2mat(spikeShapesAll(i,:));
            obj.spikeDetectionData.spikeTimes{i}=cell2mat(spikeTimesAll(i,:));
        end
        obj.spikeDetectionData.Th=Th;
        obj.spikeDetectionData.preSpikeSamplesIntrp=preSpikeSamplesIntrp;
        obj.spikeDetectionData.postSpikeSamplesIntrp=postSpikeSamplesIntrp;
        obj.spikeDetectionData.minimumDetectionIntervalSamplesIntrp=minimumDetectionIntervalSamplesIntrp;
        obj.spikeDetectionData.upSamplingFrequencySpike=obj.upSamplingFrequencySpike;
        obj.spikeDetectionData.detectionInt2uV=obj.detectionInt2uV;
    end
end
function [obj,spikeFeaturesAll]=spikeFeatureExtractionNSK(obj)
    %SPIKEFEATUREEXTRACTIONNSK Extract clustering features from the detected spike shapes.
    %   For every serial channel i with featureExtractionExist(i)==0, loads
    %   spikeShapes and detectionInt2uV from the channel's spike detection file
    %   and computes features with featuresFeatureExtractionMethod:
    %     'wavelet' - wavelet decomposition (featuresSelectedWavelet, level
    %                 featuresWTdecompositionLevel) of the waveforms on the local
    %                 grid (chPar.pSurCh{i}) of the first
    %                 min(clusteringMaxSpikesToCluster,nSpikes) spikes, selection of
    %                 the featuresNWaveletCoeff coefficients with the largest
    %                 test_ks statistic and, if featuresReduceDimensionsWithPCA,
    %                 PCA reduction to featuresDimensionReductionPCA dimensions
    %     'PCA'     - PCA of the local grid waveforms
    %   Outputs:
    %     obj              - the sorter object
    %     spikeFeaturesAll - cell(1,nCh) of feature matrices [spikes x features];
    %                        filled only when sortingDir is empty, otherwise the
    %                        features are saved to sortingFileNames.featureExtractionFile{i}
    spikeFeaturesAll=cell(1,obj.nCh);
    if isempty(obj.sortingDir)
        obj.runWithoutSaving2File=true;
    else
        obj.runWithoutSaving2File=false;
    end
    fprintf('\nExtracting spike features from channels (total %d): ',obj.nCh);
    for i=find(obj.sortingFileNames.featureExtractionExist==0)
        spikeFeatures=[];
        fprintf('%d ',i);
        if ~exist(obj.sortingFileNames.spikeDetectionFile{i},'file')
            warning(['No spike detection file was found for Channel ' num2str(i) '. Feature extraction not performed!']);
            continue;
        end
        load(obj.sortingFileNames.spikeDetectionFile{i},'spikeShapes','detectionInt2uV');
        if ~isempty(spikeShapes)
            nSurroundingChannels=numel(obj.chPar.pSurCh{i});
            [nSamples,nSpikes,~]=size(spikeShapes);
            %the first nSpikes4Clustering spikes are used for clustering
            nSpikes4Clustering=min(obj.clusteringMaxSpikesToCluster,nSpikes);
            switch obj.featuresFeatureExtractionMethod
                case 'wavelet'
                    spikeShapes=double(spikeShapes) .* detectionInt2uV;
                    if obj.featuresConcatenateElectrodes==1 %all waveforms are ordered channel by channel
                        spikeFeatures=wavedec(spikeShapes(:,1,obj.chPar.pSurCh{i}),obj.featuresWTdecompositionLevel,obj.featuresSelectedWavelet);
                        nCoeffs=numel(spikeFeatures);
                        spikeFeatures=zeros(nSpikes4Clustering,nCoeffs);
                        for j=1:nSpikes4Clustering
                            spikeFeatures(j,:)=wavedec(spikeShapes(:,j,obj.chPar.pSurCh{i}),obj.featuresWTdecompositionLevel,obj.featuresSelectedWavelet); %'haar','coif1'
                        end
                        %second decomposition with the channel dimension permuted to the front
                        spikeFeatures2=zeros(nSpikes4Clustering,nCoeffs);
                        for j=1:nSpikes4Clustering
                            spikeFeatures2(j,:)=wavedec(permute(spikeShapes(:,j,obj.chPar.pSurCh{i}),[3 2 1]),obj.featuresWTdecompositionLevel,obj.featuresSelectedWavelet); %'haar','coif1'
                        end
                        spikeFeatures=[spikeFeatures spikeFeatures2];
                        %coefficients are selected separately from each of the two decompositions
                        coeffGroupStarts=[0 nCoeffs];
                        nSelectedPerGroup=obj.featuresNWaveletCoeff/2;
                    else
                        spikeFeatures=wavedec(spikeShapes(:,1,obj.chPar.pSurCh{i}(1)),obj.featuresWTdecompositionLevel,obj.featuresSelectedWavelet);
                        nCoeffs=numel(spikeFeatures);
                        spikeFeatures=zeros(nCoeffs,nSpikes4Clustering,nSurroundingChannels);
                        for j=1:nSpikes4Clustering
                            for k=1:nSurroundingChannels
                                spikeFeatures(:,j,k)=wavedec(spikeShapes(:,j,obj.chPar.pSurCh{i}(k)),obj.featuresWTdecompositionLevel,obj.featuresSelectedWavelet); %'haar','coif1'
                            end
                        end
                        spikeFeatures=reshape(permute(spikeFeatures,[1 3 2]),[size(spikeFeatures,1)*size(spikeFeatures,3) size(spikeFeatures,2)])';
                        nCoeffs=nCoeffs*nSurroundingChannels;
                        %a single group containing the coefficients of all channels
                        coeffGroupStarts=0;
                        nSelectedPerGroup=obj.featuresNWaveletCoeff;
                    end
                    %KS test for coefficient selection (values beyond 3 std are excluded)
                    nAllCoeffs=size(spikeFeatures,2);
                    sd=zeros(1,nAllCoeffs);
                    for j=1:nAllCoeffs
                        thr_dist = std(spikeFeatures(:,j)) * 3;
                        thr_dist_min = mean(spikeFeatures(:,j)) - thr_dist;
                        thr_dist_max = mean(spikeFeatures(:,j)) + thr_dist;
                        aux = spikeFeatures(spikeFeatures(:,j)>thr_dist_min & spikeFeatures(:,j)<thr_dist_max,j);
                        if length(aux) > 10
                            sd(j)=test_ks(aux);
                        end
                    end
                    %keep the coefficients with the largest statistic in every group
                    selectedCoeffs=[];
                    for g=coeffGroupStarts
                        [~,order]=sort(sd(g+(1:nCoeffs)),'descend');
                        selectedCoeffs=[selectedCoeffs g+order(1:min(nSelectedPerGroup,nCoeffs))]; %#ok<AGROW>
                    end
                    spikeFeatures=spikeFeatures(:,selectedCoeffs);
                    if obj.featuresReduceDimensionsWithPCA
                        [~,spikeFeatures] = princomp(spikeFeatures); %run PCA for dimensionality reduction
                        spikeFeatures=spikeFeatures(:,1:min(obj.featuresDimensionReductionPCA,size(spikeFeatures,2)));
                    end
                case 'PCA' %this option was tested and gives worse results than wavelets
                    %NOTE(review): unlike the wavelet path, this uses all nSpikes spikes and takes the
                    %first rows of the PCA score matrix - confirm this matches what clustering expects
                    spikeShapes=double(spikeShapes(:,:,obj.chPar.pSurCh{i})) .* detectionInt2uV;
                    [~,spikeFeatures] = princomp(reshape(permute(spikeShapes,[1 3 2]),[nSamples*numel(obj.chPar.pSurCh{i}) nSpikes]));
                    spikeFeatures=spikeFeatures(1:obj.featuresDimensionReductionPCA,:)';
            end
        end
        if ~obj.runWithoutSaving2File
            save(obj.sortingFileNames.featureExtractionFile{i},'spikeFeatures','-v7.3');
        else
            spikeFeaturesAll{i}=spikeFeatures;
        end
    end
end
function [idx,initIdx,nClusters,avgSpikeWaveforms,stdSpikeWaveforms]=spikeClusteringNSK(obj)
    %SPIKECLUSTERINGNSK Cluster the spike features of every channel and compute cluster templates.
    %   For every serial channel i with clusteringExist(i)==0 whose feature file exists:
    %     1. if the channel has at least clusteringMinSpikesTotal spikes and a rate
    %        of at least clusteringMinimumChannelRate [Hz], its spikeFeatures are
    %        clustered with clusteringMethod ('kMeans' or 'meanShift')
    %     2. the initial clusters are merged with clusteringMergingMethod
    %     3. per merged cluster, the median waveform (avgSpikeWaveforms), the
    %        scaled median absolute deviation (stdSpikeWaveforms, 1.4826*MAD) and
    %        the number of spikes are computed
    %   Per-channel results are saved to sortingFileNames.clusteringFile{i} and the
    %   templates of the channels processed in this call to sortingFileNames.avgWaveformFile
    %   (avgClusteredWaveforms, stdClusteredWaveforms, nAvgSpk = spikes per cluster).
    %   Outputs are the values of the LAST processed channel (empty / 0 when no
    %   channel was processed). This method does not return the sorter object.

    %initialize outputs so that they are defined even if no channel is processed
    idx=[];initIdx=[];nClusters=0;avgSpikeWaveforms=[];stdSpikeWaveforms=[];
    avgClusteredWaveforms=cell(1,obj.nCh);
    stdClusteredWaveforms=cell(1,obj.nCh);
    nAvgSpk=cell(1,obj.nCh);
    if isempty(obj.sortingDir)
        obj.runWithoutSaving2File=true;
        idxAll=cell(1,obj.nCh);
        initIdxAll=cell(1,obj.nCh);
        nClustersAll=cell(1,obj.nCh);
    else
        obj.runWithoutSaving2File=false;
    end
    fprintf('\nClustering on channel (total %d): ',obj.nCh);
    for i=find(obj.sortingFileNames.clusteringExist==0) %go over all channels that were not clustered yet
        fprintf('%d ',i);
        MaxClustersTmp=obj.clusteringMaxClusters;
        if ~exist(obj.sortingFileNames.featureExtractionFile{i},'file')
            warning(['No feature extraction file was found for Channel ' num2str(i) '. Clustering not performed!']);
            continue;
        end
        load(obj.sortingFileNames.featureExtractionFile{i},'spikeFeatures');
        load(obj.sortingFileNames.spikeDetectionFile{i},'spikeTimes');
        figFileBase=[obj.sortingDir filesep 'Ch_' num2str(obj.chPar.s2r(i))]; %prefix of the figure files of this channel
        nSpikes=size(spikeFeatures,1);
        if nSpikes >= obj.clusteringMinSpikesTotal && nSpikes >= (spikeTimes(end)-spikeTimes(1))/1000*obj.clusteringMinimumChannelRate
            %%%%%%%%%%%%%%%%% Clustering %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
            switch obj.clusteringMethod
                case 'kMeans'
                    % set options to k-Means
                    opts = statset('MaxIter',obj.clusteringMaxIter);
                    try
                        initIdx = kmeans(spikeFeatures,obj.clusteringMaxClusters,'options',opts,...
                            'emptyaction','singleton','distance','city','onlinephase','on','replicates',obj.clusteringNReplicates,'start',obj.clusteringInitialClusterCentersMethod);
                    catch %if the number of samples is too low, kmeans gives an error -> try kmeans with a lower number of clusters
                        MaxClustersTmp=round(obj.clusteringMaxClusters/2);
                        initIdx = kmeans(spikeFeatures,MaxClustersTmp,'options',opts,...
                            'emptyaction','singleton','distance','city','onlinephase','on','replicates',obj.clusteringNReplicates,'start',obj.clusteringInitialClusterCentersMethod);
                    end
                case 'meanShift'
                    initIdx=zeros(nSpikes,1);
                    out=MSAMSClustering(spikeFeatures');
                    for j=1:numel(out)
                        initIdx(out{j})=j;
                    end
                otherwise
                    error('gridSorter:unknownClusteringMethod','Unknown clustering method: %s',obj.clusteringMethod);
            end
            %%%%%%%%%%%%%%%%% Merging %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
            switch obj.clusteringMergingMethod
                case 'MSEdistance'
                    %NOTE(review): spikes4Clustering is never defined in this method, so this
                    %branch errors as written - confirm the intended input before using it
                    load(obj.sortingFileNames.spikeDetectionFile{i},'spikeShapes');
                    avgSpikeWaveforms=zeros(nSpikeSamples,max(1,MaxClustersTmp),nSurroundingChannels);
                    for j=1:MaxClustersTmp
                        avgSpikeWaveforms(:,j,:)=median(spikes4Clustering(:,initIdx==j,:),2);
                    end
                    [gc,~]=obj.SpikeTempDiffMerging(permute(spikes4Clustering,[2 1 3]),initIdx,permute(avgSpikeWaveforms,[3 1 2]));
                    if obj.saveFigures
                        f1=figure('Position',[50 50 1400 900]);
                        set(f1,'PaperPositionMode','auto');
                        if obj.fastPrinting
                            imwrite(frame2im(getframe(f1)),[figFileBase 'projectionTest.jpeg'],'Quality',90);
                        else
                            print([figFileBase 'projectionTest'],'-djpeg','-r300');
                        end
                        close(f1);
                    end
                case 'projectionMeanStd'
                    %NOTE(review): the option names below carry an 'obj.' prefix - confirm they
                    %match the names parsed by projectionMerge
                    [gc,f1]=obj.projectionMerge(spikeFeatures,initIdx,'obj.clusteringMinNSpikesCluster',obj.clusteringMinNSpikesCluster,'obj.clusteringSTDMergeFac',obj.clusteringSTDMergeFac,'obj.clusteringMergeThreshold',obj.clusteringMergeThreshold,'obj.clusteringPlotProjection',obj.clusteringPlotProjection);
                    if obj.saveFigures && ~isempty(f1)
                        set(f1,'PaperPositionMode','auto');
                        if obj.fastPrinting
                            imwrite(frame2im(getframe(f1)),[figFileBase 'projectionTest.jpeg'],'Quality',90);
                        else
                            print([figFileBase 'projectionTest'],'-djpeg','-r300');
                        end
                        close(f1);
                    end
                    if obj.clusteringRunSecondMerging
                        %relabel the spikes with the merged clusters and merge once more
                        uniqueClusters=unique(gc);
                        nClusters=numel(uniqueClusters);
                        idx=zeros(nSpikes,1);
                        for k=1:nClusters
                            p=find(gc==uniqueClusters(k));
                            for j=1:numel(p)
                                idx(initIdx==p(j))=k;
                            end
                        end
                        MaxClustersTmp=nClusters;
                        initIdx=idx;
                        [gc,f1]=obj.projectionMerge(spikeFeatures,initIdx,'obj.clusteringMinNSpikesCluster',obj.clusteringMinNSpikesCluster,'obj.clusteringSTDMergeFac',obj.clusteringSTDMergeFac,'obj.clusteringMergeThreshold',obj.clusteringMergeThreshold,'obj.clusteringPlotProjection',obj.clusteringPlotProjection);
                        if obj.saveFigures && ~isempty(f1)
                            set(f1,'PaperPositionMode','auto');
                            if obj.fastPrinting
                                imwrite(frame2im(getframe(f1)),[figFileBase 'projectionTest.jpeg'],'Quality',90);
                            else
                                print([figFileBase 'projectionTest'],'-djpeg','-r300');
                            end
                            close(f1);
                        end
                    end
                otherwise
                    error('gridSorter:unknownMergingMethod','Unknown cluster merging method: %s',obj.clusteringMergingMethod);
            end
            %reclassify clusters: initial cluster p(j) belongs to merged cluster k
            uniqueClusters=unique(gc);
            nClusters=numel(uniqueClusters);
            idx=zeros(nSpikes,1);
            for k=1:nClusters
                p=find(gc==uniqueClusters(k));
                for j=1:numel(p)
                    idx(initIdx==p(j))=k;
                end
            end
            %calculate spikeShape statistics
            load(obj.sortingFileNames.spikeDetectionFile{i},'spikeShapes','detectionInt2uV');
            spikeShapes=double(spikeShapes) .* detectionInt2uV;
            [nSpikeSamples,nSpikesTotal,nSurroundingChannels]=size(spikeShapes);
            if nSpikesTotal>obj.clusteringMaxSpikesToCluster
                %keep only the spikes that were clustered (see spikeFeatureExtractionNSK)
                spikeShapes=spikeShapes(:,1:obj.clusteringMaxSpikesToCluster,:);
            end
            avgSpikeWaveforms=zeros(nSpikeSamples,max(1,nClusters),nSurroundingChannels);
            stdSpikeWaveforms=zeros(nSpikeSamples,max(1,nClusters),nSurroundingChannels);
            nSpk=zeros(1,nClusters); %reset per channel
            for j=1:nClusters
                pCluster=idx==j;
                avgSpikeWaveforms(:,j,:)=median(spikeShapes(:,pCluster,:),2);
                stdSpikeWaveforms(:,j,:)=1.4826*median(abs(bsxfun(@minus,spikeShapes(:,pCluster,:),avgSpikeWaveforms(:,j,:))),2);
                nSpk(j)=sum(pCluster); %number of spikes in cluster j (pCluster is a logical mask over all spikes)
            end
        else
            fprintf('X '); %to note that no neurons were detected on this electrode
            idx=ones(nSpikes,1);
            initIdx=ones(nSpikes,1);
            avgSpikeWaveforms=[];
            stdSpikeWaveforms=[];
            nClusters=0;
            nSpk=0;
        end
        avgClusteredWaveforms{i}=avgSpikeWaveforms;
        stdClusteredWaveforms{i}=stdSpikeWaveforms;
        nAvgSpk{i}=nSpk;
        if ~obj.runWithoutSaving2File
            save(obj.sortingFileNames.clusteringFile{i},'idx','initIdx','nClusters','avgSpikeWaveforms','stdSpikeWaveforms');
        else
            idxAll{i}=idx;
            initIdxAll{i}=initIdx;
            nClustersAll{i}=nClusters;
        end
        if obj.clusteringPlotClassification && nClusters>0 %plot initial classification
            cmap=lines;
            f2=figure('Position',[100 100 1200 800],'color','w');
            PCAfeaturesSpikeShapePlot(spikeFeatures,spikeShapes,obj.upSamplingFrequencySpike,initIdx,idx,obj.chPar.En,obj.chPar.s2r(obj.chPar.surChExtVec{i}),'hFigure',f2,'cmap',cmap);
            f3=figure('Position',[100 100 1200 800],'color','w');
            featureSubSpacePlot(spikeFeatures,idx,'hFigure',f3,'cmap',cmap);
            if obj.saveFigures
                figure(f2);
                set(f2,'PaperPositionMode','auto');
                if obj.fastPrinting
                    imwrite(frame2im(getframe(f2)),[figFileBase 'classification.jpeg'],'Quality',90);
                else
                    print([figFileBase 'classification'],'-djpeg','-r300');
                end
                figure(f3);
                set(f3,'PaperPositionMode','auto');
                if obj.fastPrinting
                    imwrite(frame2im(getframe(f3)),[figFileBase 'featureSpace.jpeg'],'Quality',90);
                else
                    print([figFileBase 'featureSpace'],'-djpeg','-r300');
                end
                if ishandle(f2)
                    close(f2);
                end
                if ishandle(f3)
                    close(f3);
                end
            end
        end
    end %go over all channels
    if ~obj.runWithoutSaving2File
        save(obj.sortingFileNames.avgWaveformFile,'avgClusteredWaveforms','stdClusteredWaveforms','nAvgSpk');
    end
end
function [obj,avgWF,stdWF,ch,mergedNeurons]=spikeMergingClustersNSK(obj)
%SPIKEMERGINGCLUSTERSNSK Merge clusters with similar templates that were detected on overlapping channels.
% Steps:
%   1. Every cluster template is marked as noise when the fraction of its samples (in a window around
%      the centre of the template) lying outside +-obj.mergingNStdNoiseDetection standard errors is
%      below obj.mergingNoiseThreshold.
%   2. Templates of clusters on a channel and on its overlapping neighbours are compared on their shared
%      channels (optionally after aligning them by cross-correlation). Pairs whose normalized squared
%      difference is <= obj.mergingThreshold are listed for merging.
%   3. Connected pairs are grouped (groupPairs). If obj.mergingRecalculateTemplates is set, one template
%      per group is recalculated from the raw spike shapes of all members. It is stored under the member
%      on the channel where the new template has its minimum, and the other members are removed.
% Output:
%   obj           - the sorter (runWithoutSaving2File is set to isempty(obj.sortingDir))
%   avgWF{c}{n}   - median waveform of neuron n on channel c (samples x channels listed in ch{c}{n})
%   stdWF{c}{n}   - robust std (1.4826*median absolute deviation) of the same waveform
%   ch{c}{n}      - channel indices on which avgWF{c}{n} / stdWF{c}{n} are defined
%   mergedNeurons - cell with one [channel cluster] matrix per group (cluster numbers before the
%                   merged neurons were removed), or [] when no merging was applied
% Unless obj.runWithoutSaving2File, the outputs and isNoise are saved to
% obj.sortingFileNames.mergedAvgWaveformFile.
    avgWF=cell(1,obj.nCh);
    stdWF=cell(1,obj.nCh);
    ch=cell(1,obj.nCh);
    isNoise=cell(1,obj.nCh);
    obj.runWithoutSaving2File=isempty(obj.sortingDir);

    %window lengths of the detected spike shapes (the up-sampling frequency itself is taken from obj)
    load(obj.sortingFileNames.spikeDetectionFile{1},'postSpikeSamplesIntrp','preSpikeSamplesIntrp');
    nSamples=postSpikeSamplesIntrp+preSpikeSamplesIntrp;
    preSpikeSamples4NoiseDetection=obj.mergingPreSpike4NoiseDetection*preSpikeSamplesIntrp;
    postSpikeSamples4NoiseDetection=obj.mergingPostSpike4NoiseDetection*postSpikeSamplesIntrp;
    pSamples4NoiseDetection=round(nSamples/2-preSpikeSamples4NoiseDetection):round(nSamples/2+postSpikeSamples4NoiseDetection);
    maxSpikeShiftSamples=obj.mergingMaxSpikeShift*obj.upSamplingFrequencySpike/1000;
    samplesPerMs=obj.upSamplingFrequencySpike/1000; %time axis of the optional test plots
    load(obj.sortingFileNames.avgWaveformFile,'avgClusteredWaveforms','stdClusteredWaveforms','nAvgSpk');

    %% 1. noise classification of every template
    validSamplesInTemplate=cell(1,obj.nCh);
    for n=1:obj.nCh
        for c=1:size(avgClusteredWaveforms{n},2)
            tmpSpikes1=squeeze(avgClusteredWaveforms{n}(:,c,:));
            tmpSpikesStd1=squeeze(stdClusteredWaveforms{n}(:,c,:));
            %samples inside +-band are indistinguishable from zero given the standard error of the template
            noiseBand=obj.mergingNStdNoiseDetection*tmpSpikesStd1/sqrt(nAvgSpk{n}(c)-1);
            pValid=~(tmpSpikes1>-noiseBand & tmpSpikes1<noiseBand);
            pValidReduces=pValid(pSamples4NoiseDetection,:);
            validSamplesInTemplate{n}{c}=sum(pValidReduces(:))/numel(pValidReduces);
            isNoise{n}{c}=validSamplesInTemplate{n}{c}<obj.mergingNoiseThreshold;
        end
    end

    %% 2. pairwise template comparison between overlapping channels
    lags=-maxSpikeShiftSamples:maxSpikeShiftSamples;
    merge=cell(obj.nCh,obj.nCh);
    fprintf('Calculating channels to merge');
    mergingList=[]; %rows: [channel1 cluster1 channel2 cluster2 shift]
    for n1=1:obj.nCh
        for n2=obj.chPar.surChExtVec{n1}(obj.chPar.pSurChOverlap{n1})
            for c1=1:size(avgClusteredWaveforms{n1},2)
                for c2=1:size(avgClusteredWaveforms{n2},2)
                    tmpSpikes1=squeeze(avgClusteredWaveforms{n1}(:,c1,obj.chPar.pSharedCh1{n1}{n2}));
                    tmpSpikes2=squeeze(avgClusteredWaveforms{n2}(:,c2,obj.chPar.pSharedCh2{n1}{n2}));
                    tmpSpikesStd1=squeeze(stdClusteredWaveforms{n1}(:,c1,obj.chPar.pSharedCh1{n1}{n2}));
                    tmpSpikesStd2=squeeze(stdClusteredWaveforms{n2}(:,c2,obj.chPar.pSharedCh2{n1}{n2}));
                    d=0; %relative shift [samples]; stays 0 when the templates are not aligned
                    if obj.mergingAllignWaveShapes
                        %align both templates at the lag of maximal mean cross-correlation
                        X=nanmean(xcorrmat(tmpSpikes1,tmpSpikes2,maxSpikeShiftSamples),2);
                        [~,pBestLag]=max(X);
                        d=lags(pBestLag);
                        if d>0
                            tmpSpikes1=tmpSpikes1(d+1:end,:);
                            tmpSpikes2=tmpSpikes2(1:end-d,:);
                            tmpSpikesStd1=tmpSpikesStd1(d+1:end,:);
                            tmpSpikesStd2=tmpSpikesStd2(1:end-d,:);
                        else
                            tmpSpikes1=tmpSpikes1(1:end+d,:);
                            tmpSpikes2=tmpSpikes2(-d+1:end,:);
                            tmpSpikesStd1=tmpSpikesStd1(1:end+d,:);
                            tmpSpikesStd2=tmpSpikesStd2(-d+1:end,:);
                        end
                    end
                    %ignore points that are within obj.mergingNStdSpikeDetection standard errors of zero in both templates
                    band1=obj.mergingNStdSpikeDetection*tmpSpikesStd1/sqrt(nAvgSpk{n1}(c1)-1);
                    band2=obj.mergingNStdSpikeDetection*tmpSpikesStd2/sqrt(nAvgSpk{n2}(c2)-1);
                    pValid=~( tmpSpikes1>-band1 & tmpSpikes1<band1 & tmpSpikes2>-band2 & tmpSpikes2<band2 );
                    %mean squared difference normalized by the summed variance of both templates
                    merge{n1,n2}(c1,c2)=nanmean((tmpSpikes1(pValid)-tmpSpikes2(pValid)).^2)./(nanvar(tmpSpikes1(pValid))+nanvar(tmpSpikes2(pValid)));
                    if merge{n1,n2}(c1,c2)<=obj.mergingThreshold
                        mergingList=[mergingList; [n1 c1 n2 c2 d]]; %#ok<AGROW>
                    end
                    if obj.mergingTestInitialTemplateMerging
                        %plotting tests
                        figure;
                        tVec=(1:numel(tmpSpikes1))/samplesPerMs;
                        plot(tVec(~pValid),tmpSpikes1(~pValid),'o','color',[0.6 0.6 0.9]);hold on;
                        plot(tVec(~pValid),tmpSpikes2(~pValid),'o','color',[0.9 0.6 0.6]);
                        plot(tVec,tmpSpikes1(:),'lineWidth',1);hold on;
                        plot(tVec,tmpSpikes2(:),'r','lineWidth',1);
                        ylabel('Voltage [\muV]');
                        xlabel('Time [ms]');
                        axis tight;
                        title(['merge=' num2str(merge{n1,n2}(c1,c2)),', T=' num2str(obj.mergingThreshold)]);
                    end
                end
            end
        end
    end
    if obj.mergingPlotStatistics
        f=figure;
        hist(cell2mat(cellfun(@(x) x(:),{merge{:}},'UniformOutput',0)'),100);
        line([obj.mergingThreshold obj.mergingThreshold],ylim,'color','r');
        ylabel('# pairs');
        xlabel('merging score');
        print([obj.sortingDir filesep 'mergingStats'],'-djpeg','-r300');
        close(f);
    end

    %% change the format of the average waveforms to support a different number of channels per neuron
    for i=1:obj.nCh
        for j=1:size(avgClusteredWaveforms{i},2)
            avgWF{i}{j}=squeeze(avgClusteredWaveforms{i}(:,j,:));
            stdWF{i}{j}=squeeze(stdClusteredWaveforms{i}(:,j,:)); %was wrongly copied from the average
            ch{i}{j}=obj.chPar.surChExtVec{i};
        end
    end

    %% 3. group and merge
    mergedNeurons=[];
    if ~isempty(mergingList)
        mergingList=unique(mergingList,'rows'); %finds the unique rows to merge
        %associate each [channel cluster] neuron with a number (row in uniqueNeurons)
        tmpList=[mergingList(:,[1 2]);mergingList(:,[3 4])];
        [uniqueNeurons,~,numericMergeList]=unique(tmpList,'rows');
        numericMergeList=reshape(numericMergeList,[numel(numericMergeList)/2 2]);
        %group connected neurons (the numeric representation corresponds to the order in uniqueNeurons)
        groups=groupPairs(numericMergeList(:,1),numericMergeList(:,2));
        if obj.mergingRecalculateTemplates
            neurons2Remove=zeros(0,2); %[channel cluster] rows; 0x2 so that (:,1) indexing works when empty
            nGroups=numel(groups);
            fprintf('Recalculating templates for merged groups (/%d)',nGroups);
            neuron=cell(1,nGroups);
            for i=1:nGroups
                fprintf('%d ',i);
                nNeuron=numel(groups{i});
                neuron{i}=uniqueNeurons(groups{i},:); %[channel cluster] of every member
                channel=neuron{i}(:,1);
                noiseSpike=false(nNeuron,1); %reset per group
                %channels recorded by all members of the group
                minimalChMap=1:obj.nCh;
                for j=1:nNeuron
                    noiseSpike(j)=isNoise{channel(j)}{neuron{i}(j,2)};
                    [~,pComN1]=intersect(minimalChMap,obj.chPar.surChExtVec{channel(j)});
                    minimalChMap=minimalChMap(pComN1);
                end
                if isempty(minimalChMap) %noise similar across the whole array can form one group without common channels
                    fprintf('Group %d (%d neurons) not merged due to lack of common channes',i,nNeuron);
                    continue;
                end
                nCommonCh=numel(minimalChMap);
                %estimate the relative delays between the members from their average waveforms
                avgWaveformsSimilarNeurons=nan(nSamples,nNeuron,nCommonCh);
                for j=1:nNeuron
                    [~,pComN1,pComN2]=intersect(minimalChMap,obj.chPar.surChExtVec{channel(j)});
                    avgWaveformsSimilarNeurons(:,j,pComN1)=avgClusteredWaveforms{channel(j)}(:,neuron{i}(j,2),pComN2);
                end
                [~,tmpDelays]=allignSpikeShapes(permute(avgWaveformsSimilarNeurons,[2 3 1]));
                %load the raw spike shapes, shift them by the delays and collect them in one matrix (NaN padding)
                newSpikeShapes=[];
                for j=1:nNeuron
                    [~,pComN1,pComN2]=intersect(minimalChMap,obj.chPar.surChExtVec{channel(j)});
                    load(obj.sortingFileNames.spikeDetectionFile{channel(j)},'spikeShapes','preSpikeSamplesIntrp','minimumDetectionIntervalSamplesIntrp','detectionInt2uV');
                    load(obj.sortingFileNames.clusteringFile{channel(j)},'idx');
                    pRelevantSpikes=find(idx==neuron{i}(j,2));
                    tmpSpikeShapes=nan(nSamples,numel(pRelevantSpikes),nCommonCh);
                    if tmpDelays(j)>=0
                        tmpSpikeShapes(1:(end-tmpDelays(j)),:,pComN1)=detectionInt2uV.*double(spikeShapes((tmpDelays(j)+1):end,pRelevantSpikes,pComN2));
                    else
                        tmpSpikeShapes((-tmpDelays(j)+1):end,:,pComN1)=detectionInt2uV.*double(spikeShapes(1:(end+tmpDelays(j)),pRelevantSpikes,pComN2));
                    end
                    newSpikeShapes=cat(2,newSpikeShapes,tmpSpikeShapes);
                end
                %robust template (median) and spread (scaled MAD); NaN padding is ignored
                avgSpikeWaveforms=nanmedian(newSpikeShapes,2);
                stdSpikeWaveforms=1.4826*nanmedian(abs(bsxfun(@minus,newSpikeShapes,avgSpikeWaveforms)),2);
                avgSpikeWaveforms(isnan(avgSpikeWaveforms))=0; %average waveform should not contain NaNs
                stdSpikeWaveforms(isnan(stdSpikeWaveforms))=0;
                %the merged neuron is kept on the channel where the new template has its minimum
                pPeak=(preSpikeSamplesIntrp-minimumDetectionIntervalSamplesIntrp):(preSpikeSamplesIntrp+minimumDetectionIntervalSamplesIntrp);
                [~,pMin]=min(min(avgSpikeWaveforms(pPeak,:,:),[],1),[],3);
                maxChannel=minimalChMap(pMin); %real channel
                pMergedNeuron=find(channel==maxChannel,1,'first'); %several members on that channel - take the first
                if isempty(pMergedNeuron) %peak sits on a channel of none of the original members
                    pMergedNeuron=1;
                    fprintf('Maximum amplitude channel for group %d, channel %d, neuron %d - detected on a different channel than any of the original neurons before merging.\n',i,neuron{i}(:,1),neuron{i}(:,2));
                end
                keptCh=neuron{i}(pMergedNeuron,1);
                keptNeu=neuron{i}(pMergedNeuron,2);
                %squeeze to samples x channels, the same format as the unmerged neurons
                avgWF{keptCh}{keptNeu}=squeeze(avgSpikeWaveforms);
                stdWF{keptCh}{keptNeu}=squeeze(stdSpikeWaveforms);
                ch{keptCh}{keptNeu}=minimalChMap;
                isNoise{keptCh}{keptNeu}=any(noiseSpike);
                %collect all other members for removal
                neurons2Remove=[neurons2Remove ; neuron{i}([1:(pMergedNeuron-1) (pMergedNeuron+1):nNeuron],:)]; %#ok<AGROW>
            end
            %remove all merged waveforms (first mark them empty, then drop empty entries)
            for i=unique(neurons2Remove(:,1))'
                p=neurons2Remove(neurons2Remove(:,1)==i,2);
                for j=1:numel(p)
                    avgWF{i}{p(j)}=[];
                    stdWF{i}{p(j)}=[];
                    ch{i}{p(j)}=[];
                    isNoise{i}{p(j)}=[];
                end
                if all(cellfun(@isempty,avgWF{i})) %a channel left without neurons gets an empty value
                    avgWF{i}=[];
                    stdWF{i}=[];
                    ch{i}=[];
                    isNoise{i}=[];
                end
            end
            for i=find(~cellfun(@isempty,avgWF))
                pEmptyNeurons=cellfun(@isempty,avgWF{i});
                avgWF{i}(pEmptyNeurons)=[];
                stdWF{i}(pEmptyNeurons)=[];
                ch{i}(pEmptyNeurons)=[];
                isNoise{i}(pEmptyNeurons)=[];
            end
            mergedNeurons=neuron;
        else
            %templates are not recalculated, so no merging is applied and mergedNeurons stays []
            %(taking the template of the largest group member would be an alternative - not implemented)
        end
    end
    if ~obj.runWithoutSaving2File
        save(obj.sortingFileNames.mergedAvgWaveformFile,'avgWF','stdWF','ch','mergedNeurons','isNoise');
    end
end
function [obj,t,ic]=spikeFittingNSK(obj)
%SPIKEFITTINGNSK Assign every detected spike to the best-matching merged template.
% For every channel, the detected spike shapes are compared with the templates (avgWF) of the
% neurons on that channel and on its overlapping neighbours, on their common channels. Each spike
% is assigned to the template with the lowest score. obj.fittingTemplateMethod selects the score:
%   1 - RMS difference, minimized over lags up to obj.fittingMaxLag [ms] in steps of
%       obj.fittingLagIntervalSamples; spike times are corrected by the best lag
%   2 - RMS difference without lag
%   3 - mean of (gauss(0)-gauss(spike-template)) with a per-channel noise std estimated
%       from the edges of the spike window
% Output:
%   t  - spike times [ms] sorted by neuron, rounded to the up-sampled resolution
%   ic - 4 x nNeurons: [channel (converted with obj.chPar.s2r); neuron; first index in t; last index in t]
% Results (t, ic, allWaveforms, isNoiseAll) are saved to obj.sortingFileNames.fittingFile unless runWithoutSaving2File.
    obj.runWithoutSaving2File=isempty(obj.sortingDir);
    fprintf('\nFitting spikes...');
    load(obj.sortingFileNames.mergedAvgWaveformFile,'avgWF','ch','isNoise');
    %NOTE: the per-channel detection thresholds (Th) were previously loaded here for every channel but never
    %used, and that loop crashed on a missing detection file before the check below could skip the channel.
    %histogram used to pick the most common per-spike noise std (method 3)
    edges=0.5:1:20.5;
    centers=(edges(1:end-1)+edges(2:end))/2;
    gauss=@(x,m,s) (1/s/sqrt(2*pi)).*exp(-(x-m).^2./2./s.^2);
    %symmetric lag vector [samples], centred at 0
    maxCCLag=obj.fittingMaxLag*obj.upSamplingFrequencySpike/1000;
    lags=[fliplr(0:-obj.fittingLagIntervalSamples:-maxCCLag) obj.fittingLagIntervalSamples:obj.fittingLagIntervalSamples:maxCCLag];
    lag=(numel(lags)-1)/2;
    lagSamples=lags((lag+1):end); %non-negative lags
    nNoiseEdgeSamples=50; %samples taken from each end of the spike window for the noise estimate
    tAll=cell(obj.nCh,1);
    fprintf('\nFitting data on Ch (total %d): ',obj.nCh);
    for i1=1:obj.nCh %go over all channels in the recording
        fprintf('%d ',i1);
        if ~exist(obj.sortingFileNames.spikeDetectionFile{i1},'file')
            warning(['No spike detection file was found for Channel ' num2str(i1) '. Fitting not performed!']);
            continue;
        end
        load(obj.sortingFileNames.spikeDetectionFile{i1},'spikeShapes','spikeTimes','detectionInt2uV');
        [nSamples,nSpikes,~]=size(spikeShapes);
        spikeShapes=double(spikeShapes)*detectionInt2uV;
        %noise std from the first and last samples of the window (was hard-coded [1:50 251:300], i.e. 300-sample spikes)
        nEdge=min(nNoiseEdgeSamples,floor(nSamples/2));
        noiseStd=std(spikeShapes([1:nEdge (nSamples-nEdge+1):nSamples],:,:),1);
        stdHist=histc(squeeze(noiseStd),edges);
        [~,pMax]=max(stdHist);
        pMax=min(pMax,numel(centers)); %histc's last bin only counts values equal to edges(end)
        noiseStd=centers(pMax);
        meanNoiseStd=mean(noiseStd); %single noise level used by method 3
        match=[];
        correspCh=[]; %[channel neuron] of every compared template
        for i2=[i1 obj.chPar.surChExtVec{i1}(obj.chPar.pSurChOverlap{i1})]
            for neu=1:numel(avgWF{i2})
                [commonCh,pComN1,pComN2]=intersect(obj.chPar.surChExtVec{i1},ch{i2}{neu});
                tmpWF=avgWF{i2}{neu}(:,pComN2);
                if obj.fittingTemplateMethod==1 %RMS difference (minkowski exp 2) minimized over lags
                    spikes=permute(spikeShapes(:,:,pComN1),[1 3 2]); %samples x channels x spikes
                    tmpMatch=zeros(lag*2+1,nSpikes);
                    for l=0:lag
                        sh=lagSamples(l+1);
                        nOverlap=(nSamples-sh)*numel(commonCh);
                        %spike is after the template
                        tmpMatch(lag+l+1,:)=( nanmean(reshape( bsxfun(@minus,spikes((sh+1):end,:,:),tmpWF(1:end-sh,:)).^2 , [nOverlap nSpikes] )) ).^(1/2);
                        %spike is before the template
                        tmpMatch(lag-l+1,:)=( nanmean(reshape( bsxfun(@minus,spikes(1:end-sh,:,:),tmpWF((sh+1):end,:)).^2 , [nOverlap nSpikes] )) ).^(1/2);
                    end
                    match=cat(3,match,tmpMatch);
                    correspCh=[correspCh ; i2 neu]; %#ok<AGROW>
                elseif obj.fittingTemplateMethod==2 %RMS difference without lag
                    spikes=reshape(permute(spikeShapes(:,:,pComN1),[1 3 2]),[nSamples*numel(commonCh) nSpikes]);
                    tmpMatch=bsxfun(@minus,spikes,tmpWF(:)).^2;
                    tmpMatch=(nanmean(tmpMatch)).^(1/2);
                    match=[match ; tmpMatch]; %#ok<AGROW>
                    correspCh=[correspCh ; i2 neu]; %#ok<AGROW>
                elseif obj.fittingTemplateMethod==3 %Gaussian dissimilarity with the estimated noise level
                    spikes=reshape(permute(spikeShapes(:,:,pComN1),[1 3 2]),[nSamples*numel(commonCh) nSpikes]);
                    tmpMatch=bsxfun(@minus,spikes,tmpWF(:));
                    tmpMatch=gauss(0,0,meanNoiseStd)-gauss(tmpMatch,0,meanNoiseStd);
                    tmpMatch=mean(tmpMatch);
                    match=[match ; tmpMatch]; %#ok<AGROW>
                    correspCh=[correspCh ; i2 neu]; %#ok<AGROW>
                end
            end
        end
        if ~isempty(match)
            if obj.fittingTemplateMethod==1
                %best lag per spike and template, then best template per spike
                [tmpMin,delay]=min(match,[],1);
                [~,idxOut]=min(tmpMin,[],3);
                delay=lags(delay(sub2ind(size(delay),ones(1,nSpikes),1:nSpikes,idxOut)));
                tAll{i1}=[correspCh(idxOut,[1 2]) (spikeTimes+delay/obj.upSamplingFrequencySpike*1000)'];
            else
                [~,idxOut]=min(match,[],1);
                tAll{i1}=[correspCh(idxOut,[1 2]) spikeTimes'];
            end
        else
            tAll{i1}=[];
        end
    end
    t=sortrows(cell2mat(tAll)); %rows: [channel neuron time], sorted by neuron then time
    if ~isempty(t)
        icCh=unique(t(:,[1 2]),'rows');
        chTransitions=unique([find(diff(t(:,1))~=0);find(diff(t(:,2))~=0)]); %last row of every neuron but the last
        nNeurons=size(icCh,1);
        ic=zeros(4,nNeurons);
        ic([1 2],:)=icCh';
        ic(4,:)=[chTransitions' size(t,1)];
        ic(3,:)=[1 ic(4,1:end-1)+1];
        t=round((obj.upSamplingFrequencySpike/1000)*t(:,3)')/(obj.upSamplingFrequencySpike/1000); %maximal resolution is one up-sampled sample
        %matrix with all average spike shapes (samples x neurons x channels), NaN outside each neuron's channels
        allWaveforms=nan(nSamples,nNeurons,obj.nCh);
        isNoiseAll=false(1,nNeurons);
        for i=1:nNeurons
            allWaveforms(:,i,ch{ic(1,i)}{ic(2,i)})=avgWF{ic(1,i)}{ic(2,i)};
            isNoiseAll(i)=isNoise{ic(1,i)}{ic(2,i)};
        end
        %move back to original channel numbers
        ic(1,:)=obj.chPar.s2r(ic(1,:));
    else
        ic=[];
        allWaveforms=[];
        isNoiseAll=[];
    end
    if ~obj.runWithoutSaving2File
        save(obj.sortingFileNames.fittingFile,'t','ic','allWaveforms','isNoiseAll');
    end
end
function obj=spikePostProcessingNSK(obj)
%SPIKEPOSTPROCESSINGNSK Compute per-neuron waveform statistics and quality figures after spike fitting.
% Loads t, ic, allWaveforms and isNoiseAll from obj.sortingFileNames.fittingFile. For every neuron not yet
% flagged in obj.sortingFileNames.postProcessignAnalysisExist, it extracts from the recording (at most
% obj.postMaxSpikes2Present spikes):
%   - raw long waveforms (obj.postExtractRawLongWaveformsFromSpikeTimes): avgLongWF (median),
%     stdLongWF and PSDSNR = mean |avg/std| on the extended grid within obj.postRawSNRStartEnd
%   - filtered waveforms (obj.postExtractFilteredWaveformsFromSpikeTimes): avgFinalWF (median),
%     stdFinalWF, spkSNR = |mean(avg/std)| on the surrounding grid within obj.postFilteredSNRStartEnd,
%     spkMaxAmp and nSpkTotal
% Results are written to the MAT-file obj.sortingFileNames.postProcessignAnalysisFile, and each neuron's
% postProcessignAnalysisExist flag is set after it is processed. Optional figures are printed to obj.sortingDir.
    %figure size follows the electrode layout (used by every per-neuron and template figure)
    [mElecs,nElecs]=size(obj.chPar.En);
    figurePosition=[100 50 min(1000,100*mElecs) min(900,100*nElecs)];
    fprintf('\nPost processing spikes...');
    load(obj.sortingFileNames.fittingFile,'t','ic','allWaveforms','isNoiseAll');
    %obj.nCh is not taken from size(): for an empty allWaveforms size() would report 1 channel
    nNeurons=size(allWaveforms,2);
    if ~isempty(ic)
        neuronNames=ic(1:2,:);
    else
        neuronNames=[];
    end
    if isempty(obj.filterObj)
        obj=getHighpassFilter(obj);
    end
    nonExistingNeurons=find(~obj.sortingFileNames.postProcessignAnalysisExist);
    matFileObj = matfile(obj.sortingFileNames.postProcessignAnalysisFile,'Writable',true);
    %a single missing entry is taken to mean that there is no analysis matfile yet
    %NOTE(review): an existing file with exactly one unprocessed neuron is also re-initialized here
    if numel(nonExistingNeurons)==1
        matFileObj.postProcessignAnalysisExist=zeros(1,nNeurons);
        nonExistingNeurons=1:nNeurons; %process every neuron (was zeros(1,nNeurons), i.e. index 0)
        matFileObj.neuronNames=neuronNames;
        if obj.postExtractRawLongWaveformsFromSpikeTimes
            nRawSamples=obj.postTotalRawWindow*obj.dataRecordingObj.samplingFrequency/1000;
            matFileObj.avgLongWF=zeros(nNeurons,obj.nCh,nRawSamples);
            matFileObj.stdLongWF=zeros(nNeurons,obj.nCh,nRawSamples);
            matFileObj.PSDSNR=zeros(1,nNeurons);
        else
            matFileObj.avgLongWF=[];matFileObj.stdLongWF=[];matFileObj.PSDSNR=[];
        end
        if obj.postExtractFilteredWaveformsFromSpikeTimes
            nFilteredSamples=obj.postTotalFilteredWindow*obj.dataRecordingObj.samplingFrequency/1000;
            matFileObj.avgFinalWF=zeros(nNeurons,obj.nCh,nFilteredSamples);
            matFileObj.stdFinalWF=zeros(nNeurons,obj.nCh,nFilteredSamples);
            matFileObj.spkSNR=zeros(1,nNeurons);
            matFileObj.nSpkTotal=zeros(1,nNeurons);
            matFileObj.spkMaxAmp=zeros(1,nNeurons);
        else
            matFileObj.avgFinalWF=[];matFileObj.stdFinalWF=[];matFileObj.spkSNR=[];matFileObj.nSpkTotal=[];matFileObj.spkMaxAmp=[];
        end
    end
    for i=nonExistingNeurons
        tTmp=t(ic(3,i):ic(4,i));
        nSpk=numel(tTmp);
        neuronString=['Neu ' num2str(ic(1,i)) '-' num2str(ic(2,i))];
        infoStr={['nSpk=' num2str(nSpk)],neuronString,['noise=' num2str(isNoiseAll(i))]};
        %NOTE(review): ic(1,:) holds channel numbers converted with chPar.s2r, but it indexes chPar.surChExtVec below - confirm
        if obj.postExtractRawLongWaveformsFromSpikeTimes
            [V_uV,T_ms]=obj.dataRecordingObj.getData(obj.dataRecordingObj.channelNumbers,tTmp(1:min(nSpk,obj.postMaxSpikes2Present))-obj.postPreRawWindow,obj.postTotalRawWindow);
            stdLongWF(1,:,:)=squeeze(std(V_uV,[],2));
            avgLongWF(1,:,:)=squeeze(median(V_uV,2));
            %baseline: mean of the median waveform up to 1 ms before the spike peak
            tmpMean=mean(avgLongWF(1,:,T_ms<(obj.postPreRawWindow-1)),3);
            V_uV=bsxfun(@minus,V_uV,tmpMean');
            avgLongWF(1,:,:)=bsxfun(@minus,avgLongWF(1,:,:),tmpMean);
            p=find(T_ms>obj.postRawSNRStartEnd(1) & T_ms<=obj.postRawSNRStartEnd(2)); %time points surrounding the spike
            tmpSNR=avgLongWF(1,obj.chPar.surChExtVec{ic(1,i)},p)./stdLongWF(1,obj.chPar.surChExtVec{ic(1,i)},p); %on the extended grid
            if obj.postPlotRawLongWaveforms
                f=figure('position',figurePosition);
                [h,hParent]=spikeDensityPlotPhysicalSpace(permute(V_uV,[3 2 1]),obj.dataRecordingObj.samplingFrequency,obj.chPar.s2r,obj.chPar.En,...
                    'hParent',f,'avgSpikeWaveforms',permute(avgLongWF(1,:,:),[3 1 2]),'logDensity',true);
                annotation('textbox',[0.01 0.89 0.1 0.1],'FitHeightToText','on','String',infoStr);
                printFile=[obj.sortingDir filesep 'neuron' neuronString '-spikeShapeRaw'];
                set(f,'PaperPositionMode','auto');
                print(printFile,'-djpeg','-r300');
                close(f);
            end
            matFileObj.PSDSNR(1,i)=mean(abs(tmpSNR(:)));
            matFileObj.stdLongWF(i,:,:)=stdLongWF;
            matFileObj.avgLongWF(i,:,:)=avgLongWF;
        end
        if obj.postExtractFilteredWaveformsFromSpikeTimes
            if obj.postExtractRawLongWaveformsFromSpikeTimes %reuse the raw traces, cut to the filtered window
                [V_uV,T_ms]=obj.filterObj.getFilteredData(V_uV( : , : , T_ms>=(obj.postPreRawWindow-obj.postPreFilteredWindow) & T_ms<(obj.postPreRawWindow-obj.postPreFilteredWindow+obj.postTotalFilteredWindow)));
            else
                [V_uV,T_ms]=obj.filterObj.getFilteredData(obj.dataRecordingObj.getData(obj.dataRecordingObj.channelNumbers,tTmp(1:min(nSpk,obj.postMaxSpikes2Present))-obj.postPreFilteredWindow,obj.postTotalFilteredWindow));
            end
            avgFinalWF(1,:,:)=squeeze(median(V_uV,2));
            stdFinalWF(1,:,:)=squeeze(std(V_uV,[],2));
            p=find(T_ms>obj.postFilteredSNRStartEnd(1) & T_ms<=obj.postFilteredSNRStartEnd(2)); %time points surrounding the spike
            pSurGrid=obj.chPar.surChExtVec{ic(1,i)}(obj.chPar.pSurCh{ic(1,i)});
            tmpSNR=avgFinalWF(1,pSurGrid,p)./stdFinalWF(1,pSurGrid,p); %on the surrounding grid
            if obj.postPlotFilteredWaveforms
                f=figure('position',figurePosition);
                [h,hParent]=spikeDensityPlotPhysicalSpace(permute(V_uV,[3 2 1]),obj.dataRecordingObj.samplingFrequency,obj.chPar.s2r,obj.chPar.En,...
                    'hParent',f,'avgSpikeWaveforms',permute(avgFinalWF(1,:,:),[3 1 2]),'logDensity',true);
                annotation('textbox',[0.01 0.89 0.1 0.1],'FitHeightToText','on','String',infoStr);
                printFile=[obj.sortingDir filesep 'neuron' neuronString '-spikeShape'];
                set(f,'PaperPositionMode','auto');
                print(printFile,'-djpeg','-r300');
                close(f);
            end
            matFileObj.avgFinalWF(i,:,:)=avgFinalWF;
            matFileObj.stdFinalWF(i,:,:)=stdFinalWF;
            matFileObj.spkSNR(1,i)=abs(mean(tmpSNR(:)));
            matFileObj.spkMaxAmp(1,i)=max(abs(avgFinalWF(:))); %the extremal spike amplitude
            matFileObj.nSpkTotal(1,i)=ic(4,i)-ic(3,i)+1;
        end
        matFileObj.postProcessignAnalysisExist(1,i)=1;
    end %loop over all neurons
    if obj.postPlotAllAvgSpikeTemplates && nNeurons>0
        nPlotAxis=min(ceil(sqrt(nNeurons)),5);
        nPlotsPerFigure=nPlotAxis.^2;
        for i=1:nNeurons
            pPlot=mod(i-1,nPlotsPerFigure)+1; %position of neuron i in its figure (1..nPlotsPerFigure)
            if pPlot==1
                f=figure('position',figurePosition);
            end
            neuronString=['Neu ' num2str(ic(1,i)) '-' num2str(ic(2,i)),',N' num2str(isNoiseAll(i))];
            h=subaxis(f,nPlotAxis,nPlotAxis,pPlot,'S',0.02,'M',0.02);
            activityTracePhysicalSpacePlot(h,obj.chPar.s2r,0.03*squeeze(allWaveforms(:,i,:))',obj.chPar.En,'scaling','none');
            text(0,0,neuronString);
            if pPlot==nPlotsPerFigure || i==nNeurons %figure is full or last neuron - print it
                printFile=[obj.sortingDir filesep 'avgSpikeShapes' num2str(ceil(i/nPlotsPerFigure))];
                if obj.fastPrinting
                    imwrite(frame2im(getframe(f)),[printFile '.jpeg'],'Quality',90);
                else
                    set(f,'PaperPositionMode','auto');
                    print(printFile,'-djpeg','-r300');
                end
                close(f);
            end
        end
    end
    if obj.postPlotSpikeReliability && nNeurons>0
        %read each stored vector once
        spkSNR=matFileObj.spkSNR;
        PSDSNR=matFileObj.PSDSNR;
        nSpkTotal=matFileObj.nSpkTotal;
        spkMaxAmp=matFileObj.spkMaxAmp;
        f=figure('position',[300 50 900 700]);
        mNSpkTotal=mean(nSpkTotal);
        %five extra points act as a size legend (labelled with the corresponding rate [spikes/s])
        plotSpikeSNR=[spkSNR (max(spkSNR)-(max(spkSNR)-min(spkSNR))*0.01)*ones(1,5)];
        plotPSDSNR=[PSDSNR [0.9 1 1.1 1.2 1.3]*((max(PSDSNR)+min(PSDSNR))/2)];
        legendSpikeNums=round([mNSpkTotal/6 mNSpkTotal/3 mNSpkTotal mNSpkTotal*3 mNSpkTotal*6]);
        plotnSpkTotal=[nSpkTotal legendSpikeNums];
        plotspkMaxAmp=[spkMaxAmp ones(1,5)*min(spkMaxAmp)];
        scatter(plotSpikeSNR,plotPSDSNR,(plotnSpkTotal/mNSpkTotal)*50+5,plotspkMaxAmp,'linewidth',2);
        text(plotSpikeSNR(end-4:end)*1.02,plotPSDSNR(end-4:end),num2str(legendSpikeNums'/(obj.dataRecordingObj.recordingDuration_ms/1000),3),'FontSize',8)
        xlabel('$$\sqrt{SNR_{spike}}$$','Interpreter','latex','FontSize',14);ylabel('$$\sqrt{SNR_{PSD}}$$','Interpreter','latex','FontSize',14);
        cb=colorbar('position',[0.8511 0.6857 0.0100 0.2100]);ylabel(cb,'Max spk. amp.');
        printFile=[obj.sortingDir filesep 'SNR_spikePSD'];
        set(f,'PaperPositionMode','auto');
        print(printFile,'-djpeg','-r300');
        close(f);
    end
end
%Plots for testing
%{
for i=1:numel(pV)
for j=1:p(pV(i))
h=subaxis(nPlots,nPlots,c,'S',0.03,'M',0.03);
activityTracePhysicalSpacePlot(h,ch{pV(i)}{j},0.03*squeeze(avgWF{pV(i)}{j})',obj.chPar.rEn,'scaling','none');
title([num2str(pV(i)) '-' num2str(j)]);
c=c+1;
end
end
%}
function [publicProps]=getProperties(obj)
%GETPROPERTIES List the properties with public SetAccess together with their current values.
%Output:
% - publicProps : nPublic x 2 cell, column 1 = property name, column 2 = its current value in obj
    metaClassData=metaclass(obj);
    allPropName={metaClassData.PropertyList.Name}';
    allPropSetAccess={metaClassData.PropertyList.SetAccess}';
    %SetAccess entries that are not the char 'public' (e.g. 'private', 'none', meta.class lists) do not match
    publicNames=allPropName(strcmp(allPropSetAccess,'public'));
    publicProps=cell(numel(publicNames),2);
    publicProps(:,1)=publicNames;
    %collect the values of the selected properties (previously read by index from the unfiltered name list)
    for i=1:numel(publicNames)
        publicProps{i,2}=obj.(publicNames{i});
    end
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%% Additional functions %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
methods (Static)
function [clusterMerged,Merge]=SpikeTempDiffMerging(spikeShapes,Clusters,Templates,crit)
%Merge clusters whose spikes fit another cluster's template almost as well as their own.
%For every ordered pair (i,j), i~=j, the score is
%   Merge(i,j) = median_k ||spike_k - template_j||^2 / median_k ||spike_k - template_i||^2
%where k runs over the spikes of cluster i. Clusters i and j are merged when both Merge(i,j)
%and Merge(j,i) are below crit; merging is transitive (groups are joined).
%synthax : [clusterMerged,Merge]=SpikeTempDiffMerging(spikeShapes,Clusters,Templates,crit)
%input:
% - spikeShapes : all the spike shapes, Nspikes x sizeSpike x NbChannels
% - Clusters : cluster number (1..nTemplates) of every spike in spikeShapes
% - Templates : the templates corresponding to the clusters
%   nChannels x sizeTemplate x nTemplates (sizeTemplate must equal sizeSpike)
% - crit : the value of the threshold to merge clusters (default value : 1.5)
%output :
% - clusterMerged : 1 x nTemplates group label of every cluster (merged clusters share a label)
% - Merge : binary symmetric merging decision matrix nclust x nclust
    if nargin<4
        crit=1.5;
    end
    numClust=size(Templates,3);
    sizeSpike=size(spikeShapes,2);
    nChannels=size(spikeShapes,3);
    %squared distance of every (flattened) spike to template t; the template is flattened in the same
    %sample-major/channel-minor order as the spikes
    sqDist=@(spikesFlat,t) sum(bsxfun(@minus,spikesFlat,reshape(Templates(:,:,t)',1,[])).^2,2);
    %build score matrix of clusters
    Merge=zeros(numClust,numClust);
    for i=1:numClust
        clSel=find(Clusters==i);
        spikesFlat=reshape(spikeShapes(clSel,:,:),numel(clSel),sizeSpike*nChannels);
        ownMedian=median(sqDist(spikesFlat,i)); %independent of j - computed once per cluster
        for j=[1:(i-1) (i+1):numClust]
            Merge(i,j)=median(sqDist(spikesFlat,j))/ownMedian;
        end
    end
    %Symmetrize matrix and build the groups
    clusterMerged=1:numClust; % a group is assigned to every cluster
    for k=1:numClust
        for m=k+1:numClust
            if Merge(k,m)<crit && Merge(m,k)<crit
                Merge(k,m)=1;Merge(m,k)=1;
                %relabel k's whole group (k may already carry another cluster's label)
                clusterMerged(clusterMerged==clusterMerged(k))=clusterMerged(m);
            else
                Merge(k,m)=0;Merge(m,k)=0;
            end
        end
    end
end
function [gc,f]=projectionMerge(spikeFeatures,initIdx,varargin)
%Decide which clusters to merge from 1-D projections of their features.
%For every pair of clusters (k,m), the features of both clusters are projected on the vector
%connecting their median centers, and the histograms of the two projections are compared.
%The score D(k,m) is the fraction of spikes (out of those inside the histogram range) that lie
%near, within std/obj.clusteringSTDMergeFac, or beyond the point where the two histograms cross.
%D is 1 when no crossing is found (one cluster inside the other) and 0 when either cluster has
%fewer than obj.clusteringMinNSpikesCluster spikes. Pairs with D>obj.clusteringMergeThreshold
%are put in the same group.
%synthax : [gc,f]=projectionMerge(spikeFeatures,initIdx,varargin)
%input:
% - spikeFeatures : spike features, nSpikes x nFeatures
% - initIdx : the cluster index of every spike (assumed to be numbered 1..nClusters - TODO confirm)
% vararin - 'property','value' pairs. The property name is used as an assignment target in eval,
%   so defaults are overridden with names such as 'obj.clusteringMergeThreshold'.
%   NOTE(review): option names are passed to eval - never forward untrusted input here.
%output :
% - gc: group label of every cluster (clusters sharing a label should be merged); 1 for <=1 cluster
% - f: a figure handle for the generated plot ([] when obj.clusteringPlotProjection is 0)
    %default variables
    obj.clusteringMinNSpikesCluster=10;
    obj.clusteringSTDMergeFac=2;
    obj.clusteringMergeThreshold=0.18;
    obj.clusteringPlotProjection=1;
    %Collects all options
    for i=1:2:length(varargin)
        eval([varargin{i} '=' 'varargin{i+1};'])
    end
    nClustersIn=numel(unique(initIdx));
    if nClustersIn<=1
        gc=1;
        f=[];
        return;
    end
    %find robust cluster centers (median along spikes - dim 1 - also for single-spike clusters)
    cent=zeros(nClustersIn,size(spikeFeatures,2));
    for k=1:nClustersIn
        cent(k,:)=median(spikeFeatures(initIdx==k,:),1);
    end
    if obj.clusteringPlotProjection
        f=figure('Position',[50 50 1400 900]);
    else
        f=[];
    end
    D=zeros(nClustersIn);
    gc=1:nClustersIn; % a group is assigned to every cluster
    for k=2:nClustersIn
        for m=1:k-1
            v=(cent(k,:)-cent(m,:));
            p1=projectionND(v,spikeFeatures(initIdx==k,:));
            p2=projectionND(v,spikeFeatures(initIdx==m,:));
            if numel(p1)<obj.clusteringMinNSpikesCluster || numel(p2)<obj.clusteringMinNSpikesCluster
                D(k,m)=0;
                std_p1=[];
                std_p2=[];
                v=[];
            else
                mp1=median(p1);
                mp2=median(p2);
                %robust std estimates (MAD / 0.6745)
                std_p1=median(abs( p1-mp1 ),2) / 0.6745;
                std_p2=median(abs( p2-mp2 ),2) / 0.6745;
                nV=numel(p1)+numel(p2);
                s=sign(mp1-mp2);
                %edges must be divided in a way that preserves the extreme edges on both sides
                edges=(mp2-std_p2*s):(((mp1+std_p1*s)-(mp2-std_p2*s))/(round(log(nV))/10)/20):(mp1+std_p1*s);
                n1=histc(p1,edges);
                n2=histc(p2,edges);
                n=n2(1:end-1)-n1(1:end-1); %eliminate the last bin (values equal to the last edge)
                firstCross=find(n(2:end)<=0 & n(1:end-1)>0,1,'first')+1;
                secondCross=find(n(1:end-1)>=0 & n(2:end)<0,1,'last');
                intersection=(edges(firstCross) + edges(secondCross))/2;
                if isempty(intersection) %one cluster is contained within the other
                    D(k,m)=1;
                else
                    D(k,m)=( sum(p2>(intersection-std_p2/obj.clusteringSTDMergeFac)) + sum(p1<(intersection+std_p1/obj.clusteringSTDMergeFac)) ) /sum([n1 n2] );
                end
            end
            if D(k,m)>obj.clusteringMergeThreshold
                %move the whole group of k into the group of m
                groupOfK=gc(k);
                groupOfM=gc(m);
                gc(gc==groupOfK)=groupOfM;
            end
            if obj.clusteringPlotProjection
                subaxis(nClustersIn,nClustersIn,(m-1)*nClustersIn+k,'Spacing', 0.001, 'Padding', 0.001, 'Margin', 0.001);
                edges=[min([p1 p2]):((max([p1 p2])-min([p1 p2]))/30):max([p1 p2])]; %edges different from before
                n1=hist(p1,edges);
                n2=hist(p2,edges);
                bar(edges,[n1;n2]',1,'stacked');
                axis tight;
                set(gca,'XTickLabel',[],'YTickLabel',[]);
                subaxis(nClustersIn,nClustersIn,(k-1)*nClustersIn+m,'Spacing', 0.001, 'Padding', 0.001, 'Margin', 0.001);
                strTxt={['F=' num2str(D(k,m))],['s1=' num2str(std_p1)],['s2=' num2str(std_p2)],['D=' num2str(sqrt(sum(v.^2)))]};
                text(0,0.5,strTxt);
                axis off;
            end
        end
    end
    function p=projectionND(v,d)
        %Signed length of the projection of every dot in d on the direction of v
        %v = [1 x N] - vector
        %d = [M X N] - M dot locations
        %p = [1 x M] - (v*d_i')/|v|, i.e. |d_i|*cos(angle(v,d_i)); 0 (not NaN) for dots at the origin
        %             and also correct for N==1, where sum(d'.^2) used to collapse to a scalar
        p=(v*d')./sqrt(sum(v.^2));
    end
end
end %methods (static)
end