Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
comik/libsvm_plotroc.m
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
97 lines (92 sloc)
3.51 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
function [auROC auPRC] = plotroc(y,x,params)
% PLOTROC  Area under the ROC and precision-recall curves for a binary SVM.
% Reused and adapted from the Matlab code of LIBSVM.
%
% [auROC, auPRC] = plotroc(training_label, training_instance [, libsvm_options -v cv_fold])
%   Use cross-validation on the training data to get decision values and
%   score them. cv_fold defaults to 5 when no '-v' option is given.
%
% [auROC, auPRC] = plotroc(testing_label, testing_instance, model)
%   Use the given model (any non-char third argument) to predict the
%   testing data and score the resulting decision values.
%
% [auROC, auPRC] = plotroc(true_label, decision_value, 'personal')
%   Score decision values that were computed elsewhere: 'y' holds the
%   ground-truth labels and 'x' holds the predicted values.
%
% Labels must be +1 / -1. Called with fewer than two arguments, the
% function prints this help text and returns empty outputs.
%
% Example:
%
%   load('heart_scale.mat');
%   plotroc(heart_scale_label, heart_scale_inst,'-v 5');
%
%   [y,x] = libsvmread('heart_scale');
%   model = libsvmtrain(y,x);
%   plotroc(y,x,model);

    % Default the outputs so the help-only path returns cleanly even when
    % the caller asked for return values.
    auROC = [];
    auPRC = [];

    % Reset the random seed so the cross-validation split is repeatable.
    % Legacy syntax kept on purpose: switching generators would change the
    % fold assignment produced by randperm.
    rand('state',0);

    % nargin is used instead of exist('params'): without the 'var'
    % qualifier, exist also reports files/folders/functions called "params".
    hasParams = nargin >= 3;

    if nargin < 2
        % mfilename keeps the help lookup correct whatever the file is named.
        help(mfilename);
        return
    elseif isempty(y) || isempty(x)
        error('Input data is empty');
    elseif sum(y == 1) + sum(y == -1) ~= length(y)
        % every label has to be exactly +1 or -1
        error('ROC is only applicable to binary classes with labels 1, -1');
    elseif hasParams && ~ischar(params)
        % Third argument is a trained model: predict, then score.
        model = params;
        [predict_label,mse,deci] = libsvmpredict(y,x,model);
        % NOTE(review): presumably the decision values are oriented toward
        % model.Label(1), so this product makes larger mean "+1" - confirm.
        [auROC auPRC] = roc_curve(deci*model.Label(1),y);
    elseif hasParams && strcmp(params, 'personal')
        % Caller supplies the decision values directly in x. roc_curve
        % already returns the precision-recall area, so no second
        % perfcurve call is needed here.
        [auROC auPRC] = roc_curve(x, y);
    else
        % Cross-validation mode: params is an option string (or absent).
        if ~hasParams
            params = [];
        end
        [param,fold] = proc_argv(params); % split '-v n' from the other options
        if fold <= 1
            error('The number of folds must be greater than 1');
        else
            [deci,label_y] = get_cv_deci(y,x,param,fold); % decision values and labels after cross-validation
            [auROC auPRC] = roc_curve(deci,label_y);
        end
    end
end
function [resu,fold] = proc_argv(params)
% PROC_ARGV  Split the cross-validation fold count out of an option string.
%
%   params - option string (char) or empty
%   resu   - params with every '-v <n>' occurrence removed
%   fold   - <n> taken from the last '-v <n>'; 5 when params has no '-v'
%
% Raises an error when '-v' is present but is not followed by an integer.

    resu = params;
    fold = 5; % default number of folds

    if isempty(params) || isempty(regexp(params,'-v','once'))
        return
    end

    % Capture the digits directly instead of re-parsing the match with the
    % deprecated strread; this also copes with '-v' appearing more than once.
    foldTokens = regexp(params,'-v\s+(\d+)','tokens');
    if isempty(foldTokens)
        error('Number of CV folds must be specified by "-v cv_fold"');
    end

    % When the option is repeated, the last occurrence is the one used.
    fold = str2double(foldTokens{end}{1});

    % Strip every '-v <n>' so none is left in the returned option string.
    resu = regexprep(params,'-v\s+\d+','');
end
function [deci,label_y] = get_cv_deci(prob_y,prob_x,param,nr_fold)
% GET_CV_DECI  Collect held-out decision values over nr_fold random folds.
%
%   prob_y  - label vector
%   prob_x  - instance matrix, one row per label
%   param   - option string handed to libsvmtrain
%   nr_fold - number of folds
%
%   deci    - decision value of every sample, taken from the fold in which
%             that sample was held out and multiplied by the fold model's
%             first label
%   label_y - labels stored at the same positions as deci

    n = length(prob_y);
    deci = ones(n,1);
    label_y = ones(n,1);
    shuffled = randperm(n);

    for k = 1:nr_fold
        % Contiguous slice of the shuffled order that forms fold k.
        first = floor((k-1)*n/nr_fold) + 1;
        last = floor(k*n/nr_fold);
        held_out = shuffled((first:last)');

        % Everything outside the slice is used for fitting, in index order.
        keep = true(n,1);
        keep(held_out) = false;
        fit_rows = find(keep);

        fold_model = libsvmtrain(prob_y(fit_rows),prob_x(fit_rows,:),param);
        [fold_pred,fold_acc,fold_deci] = libsvmpredict(prob_y(held_out),prob_x(held_out,:),fold_model);

        deci(held_out) = fold_deci.*fold_model.Label(1);
        label_y(held_out) = prob_y(held_out);
    end
end
function [auROC, auPRC] = roc_curve(deci,label_y)
% ROC_CURVE  Areas under the ROC and precision-recall curves.
%
%   deci    - decision values; larger values are ranked first
%   label_y - labels (+1 / -1) matching deci
%
%   auROC   - area under the ROC curve, accumulated over the samples
%             sorted by descending decision value
%   auPRC   - fourth output of perfcurve with recall as the X criterion,
%             precision as the Y criterion and 1 as the positive class
%
% If either class is absent from label_y the rates below divide by zero
% and auROC is not meaningful.

    % Force column vectors so the indexing below also works when the
    % caller passes rows.
    deci = deci(:);
    label_y = label_y(:);

    [val,ind] = sort(deci,'descend');
    roc_y = label_y(ind);

    % Cumulative false/true positive rates while walking down the ranking.
    fpr = cumsum(roc_y == -1)/sum(roc_y == -1);
    tpr = cumsum(roc_y == 1)/sum(roc_y == 1);

    % Sum of rectangles: width = step in FPR, height = TPR at the right edge.
    auROC = sum((fpr(2:end)-fpr(1:end-1)).*tpr(2:end));

    % Precision-recall area from perfcurve. The curve itself is returned
    % into separate variables so the ROC vectors above remain available
    % for the optional plot below.
    [prc_x,prc_y,thre,auPRC] = perfcurve(label_y, deci, 1, 'xCrit', 'reca', 'yCrit', 'prec');

    %plot(fpr,tpr);
    %xlabel('False Positive Rate');
    %ylabel('True Positive Rate');
    %title(['ROC curve of (AUC = ' num2str(auROC) ' )']);
end