Skip to content

Commit

Permalink
Fixed: 1. ommittedFastaIds also prints sequence indices; 2. comikl_m…
Browse files Browse the repository at this point in the history
…ain function correctly handles the adjustment of testIndices
  • Loading branch information
snikumbh committed Apr 25, 2017
1 parent b438741 commit 9f9278e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 20 deletions.
41 changes: 27 additions & 14 deletions comikl_main_with_weight_vector.m
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,31 @@

% Read positives
logMessages(debugMsgLocation,sprintf('Positives...'), debugLevel);
[nPos, posIndices, allSeqsRawPosFasta] = readFastaSequences(givenPosFastaFilename, chunkSizeInBps, givenNPos, outputFolder, debugLevel, debugMsgLocation);
[nPos, ommittedPosIndices, allSeqsRawPosFasta] = readFastaSequences(givenPosFastaFilename, chunkSizeInBps, givenNPos, outputFolder, debugLevel, debugMsgLocation);
% nPos and nNeg could be smaller than or equal to givenNPos and givenNNeg resp. due to omitting any sequences
% whose length is less than the chunkSizeInBps
%
% Read Negatives
logMessages(debugMsgLocation,sprintf('Negatives...'), debugLevel);
[nNeg, negIndices, allSeqsRawNegFasta] = readFastaSequences(givenNegFastaFilename, chunkSizeInBps, givenNNeg, outputFolder, debugLevel, debugMsgLocation);
[nNeg, ommittedNegIndices, allSeqsRawNegFasta] = readFastaSequences(givenNegFastaFilename, chunkSizeInBps, givenNNeg, outputFolder, debugLevel, debugMsgLocation);
% Combine positives + negatives
nBags = nPos + nNeg;
% Check if we need to update the testIndices if some sequences have been ommitted
% if ommittedPosIndices and ommittedNegIndices are empty, then no update is required
if ~isempty(posIndices)
testIndices(find(ismember(testIndices, posIndices))) = [];
if ~isempty(ommittedPosIndices)
%nOmmitted = length(find(ismember(testIndices, ommittedPosIndices)));
testIndices(find(ismember(testIndices, ommittedPosIndices))) = [];
%logMessages(debugMsgLocation,sprintf('%d positive sequences ommitted from test set.\n', nOmmitted), debugLevel);
end
if ~isempty(negIndices)
testIndices(find(ismember(testIndices, nPos+negIndices))) = [];%offset by nPos
if ~isempty(ommittedNegIndices)
%nOmmitted = length(find(ismember(testIndices, ommittedNegIndices)));
testIndices(find(ismember(testIndices, nPos+ommittedNegIndices))) = [];%offset by nPos
%logMessages(debugMsgLocation,sprintf('%d negative sequences ommitted from test set.\n', nOmmitted), debugLevel);
end
% Additionally, one needs to account for the reduction in the number of sequences affecting the indices
% Additionally, one needs to account for the reduction in the number of sequences affecting the indices. Remove indices > nBags
%testIndices = testIndices(testIndices <= nBags);
for i=1:length(testIndices)
testIndices(i) = testIndices(i) - nnz(posIndices < testIndices(i)) - nnz(negIndices < testIndices(i));
testIndices(i) = testIndices(i) - nnz(ommittedPosIndices < testIndices(i)) - nnz( (nPos+ommittedNegIndices) < testIndices(i));
end
% Put them together
allSeqsRawFasta.sequence = cell(1, nBags);
Expand All @@ -64,8 +69,12 @@
allSeqsRawFasta.sequence{i+nPos} = allSeqsRawNegFasta.sequence{i};%offset by nPos
end
clear allSeqsRawNegFasta;
% Write test indices to disk for reproducibility
dlmwrite(strcat(outputFolder, '/testIndices.txt'), testIndices);
% Write the revised set of test indices to disk for reproducibility

if ~isempty(ommittedPosIndices) | ~isempty(ommittedNegIndices)
sortedTestIndices = sort(testIndices);
dlmwrite(strcat(outputFolder, '/testIndices_New.txt'), sortedTestIndices');
end

% generate labels
% handle imbalance, multiply this to C for negatives
Expand Down Expand Up @@ -112,7 +121,7 @@
nBagsForTest = nBags-idx1;
logMessages(debugMsgLocation,sprintf('#Test: %d\n#Train: %d\n', nBagsForTest, nBagsForTrain), debugLevel);
logMessages(debugMsgLocation,sprintf('Segmentation statistics:\n'), debugLevel);
logMessages(debugMsgLocation,sprintf(' Mean number of instancesa in bags: %.2f\t Median: %d\t Max.: %d\t Min.: %d\n', mean(thisInstances), median(thisInstances), max(thisInstances), min(thisInstances)), debugLevel);
logMessages(debugMsgLocation,sprintf(' Mean number of instances in bags: %.2f\t Median: %d\t Max.: %d\t Min.: %d\n', mean(thisInstances), median(thisInstances), max(thisInstances), min(thisInstances)), debugLevel);
% Variable 'nBags' holds nBagsForTrain + nBagsForTest
%
% 2. Compute the instanceWide kernel that is used further
Expand Down Expand Up @@ -250,13 +259,13 @@
%
% Use the best param combinations to train using the whole training set and predict the test set here
%
logMessages(debugMsgLocation,sprintf('---Test---\n'), debugLevel);
logMessages(debugMsgLocation,sprintf('---Re-training and Test---\n'), debugLevel);
% Re-train the model with all train instances together
% trainIndicesInBags is 1:nBagsForTrain
nClusters = bestParamComb.best_nClusters;
conformalXformationParam = bestParamComb.best_sigma;
logMessages(debugMsgLocation,sprintf('nClusters: %d, sigma: %d\n', nClusters, conformalXformationParam), debugLevel);
logMessages(debugMsgLocation,sprintf('Re-training uisng complete trainning set examples...\n'), debugLevel);
logMessages(debugMsgLocation,sprintf('Best param-values:\nSVM-Cost:%.3f, nClusters: %d, sigma: %.3f\n', nClusters, conformalXformationParam), debugLevel);
logMessages(debugMsgLocation,sprintf('Re-training using complete trainning set examples...\n'), debugLevel);

logMessages(debugMsgLocation,sprintf('Conformed Multi-instance kernel for all bags...\n'), debugLevel);
[allSeqsConformedSetKernel, rawConformedSetKernel, allSeqsTransformationKernel, clusterCentres] = ...
Expand Down Expand Up @@ -364,7 +373,11 @@
%
logMessages(debugMsgLocation,sprintf('Bias value: %.4f\n', biasValue), debugLevel);
yhat = yhatTemp + biasValue;
if(length(Youtertest) == length(yhat))
logMessages(debugMsgLocation, sprintf('%d and %d: Dimensions of yhat and ytest match!\n', length(Youtertest), length(yhat)), debugLevel);
end
dlmwrite(strcat(outputFolder,'/predictedLabelsWeightVector.txt'), yhat');
dlmwrite(strcat(outputFolder,'/givenLabels.txt'), Youtertest');
logMessages(debugMsgLocation,sprintf('Predicted labels written to disk. Computing the auROC/auPRC, this may take some time...\n'), debugLevel);
tic;
[test_teAUROC, test_teAUPRC] = libsvm_plotroc(Youtertest', yhat', 'personal');
Expand Down
2 changes: 1 addition & 1 deletion comikl_wrapper.m
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ case char('distal-proximal-cgi')
fprintf(testFid, 'testIndices for this outer fold:\n');
fclose(testFid);
if exist(testIndicesFilename, 'file') == 2
dlmwrite(testIndicesFilename, testIndices','-append', 'roffset', 1);
dlmwrite(testIndicesFilename, sort(testIndices'),'-append', 'roffset', 1);
logMessages(debugMsgLocation, sprintf('Test indices for this outer fold written to disk'), debugLevel);
end

Expand Down
10 changes: 5 additions & 5 deletions readFastaSequences.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function [nSeqs, indices, passFasta] = readFastaSequences(txt, chunkSizeInBps, givenNSeqs, outputFolder, debugLevel, debugMsgLocation)
function [nSeqs, ommittedIndices, passFasta] = readFastaSequences(txt, chunkSizeInBps, givenNSeqs, outputFolder, debugLevel, debugMsgLocation)

% FUNCTION FASTA=READFASTA(FILENAME)
%
Expand All @@ -25,7 +25,7 @@
fasta.legend = cell(1, givenNSeqs);
fasta.sequence = cell(1, givenNSeqs);
i=0;
indices = [];
ommittedIndices = [];
fasta.title=[txt,', read: ',date];
% READS SEQUENCES AND THEIR LEGENDS
fid=fopen(txt,'r');
Expand All @@ -44,8 +44,8 @@
if lengths(end) >= chunkSizeInBps
fasta.sequence{i}=[fasta.sequence{i},upper(lineRead)];
else % we omit that sequence
indices = [indices i];
fprintf(ofid, fasta.legend{i});
ommittedIndices = [ommittedIndices i];
fprintf(ofid, '%d\t%s', i, fasta.legend{i});
fprintf(ofid, '\n');
end
end
Expand All @@ -54,7 +54,7 @@
fclose(ofid);
fclose(fid);
ineligibleLengths = lengths(lengths < chunkSizeInBps);
ineligibleIndices = find(lengths <= chunkSizeInBps);
ineligibleIndices = find(lengths < chunkSizeInBps);
nSeqs = givenNSeqs;
passFasta = fasta;
if size(ineligibleLengths, 2) > 0
Expand Down

0 comments on commit 9f9278e

Please sign in to comment.