You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
142 lines
4.4 KiB
Matlab
142 lines
4.4 KiB
Matlab
function calc_classification_error(V,D,classDict,data,class,show,extraStats)
%CALC_CLASSIFICATION_ERROR test classification performance on a labelled data set.
%   calc_classification_error(V,D,classDict,data,class,show) computes a
%   response vector y for every column of DATA (via calc_sparse_representation)
%   and compares the class predicted from y with the true class CLASS(pattern).
%   Four predictors are evaluated and their percentage errors are printed:
%     errkNN - majority vote over the classDict labels of the k largest
%              responses, for k = 1..min(15,size(V,1)) (displayed as a vector)
%     errMax - classDict label of the single largest response
%     errSum - row index of the largest entry of D*y
%     errErr - class with the lowest reconstruction error; only computed when
%              GP.network is 'subnets' or 'single' (otherwise reported as 0)
%   Also printed: the execTime accumulated by calc_sparse_representation and
%   the means (nanmean) of the per-pattern sparsity s and of mean(nmse).
%
%   V          - dictionary; n = size(V,1) nodes, reconstruction is V'*y
%   D          - matrix applied to y; the row index of max(D*y) is taken as the class
%   classDict  - class label of each dictionary node (length n)
%   data       - test patterns, one per column
%   class      - true class label of each pattern (integers 1..max(class))
%   show       - nonzero: plot each pattern's result in figure(pattern+show-1)
%   extraStats - optional, default 0: nonzero additionally records sparsity,
%                reconstruction error and top-two D*y ratio, split by whether
%                the last predictor evaluated classified the pattern correctly

if nargin<7 || isempty(extraStats), extraStats=0; end

GP=global_parameters;
numPatterns=length(class);
numClasses=max(class);
n=size(V,1);
disp(['testing performance on data set with ',int2str(numPatterns),' elements, in ',int2str(numClasses),' classes']);

%TEST classification performance by comparing the class predicted by the output of the
%network with the true class defined in the dataset, for all patterns in the test set
maxk=min(15,n);
knnEdges=0.5:1:numClasses+0.5; %histc bins centred on the integer class labels 1..numClasses
eAll=zeros(1,numPatterns);
sAll=zeros(1,numPatterns);
errCountkNN=zeros(1,maxk);
errCountMax=0;
errCountSum=0;
errCountErr=0;

if extraStats
  errSparsity=[];
  repSparsity=[];
  errError=[];
  repError=[];
  errRatio=[];
  repRatio=[];
end

%reconstruction-error classification below only handles these two network types
if ~any(strcmp(GP.network,{'subnets','single'}))
  warning('calc_classification_error:unknownNetwork', ...
    'GP.network=''%s'' not recognised: reconstruction-error classification skipped (errErr reported as 0)', ...
    GP.network);
end

execTime=0;
for pattern=1:numPatterns
  if rem(pattern,100)==0, fprintf(1,'.%i.',pattern); end

  %calculate coefficients for/response to input image
  x=data(:,pattern);
  [y,e,s,nmse,execTime,sTrace]=calc_sparse_representation(V,x,classDict,execTime);
  eAll(pattern)=mean(nmse); %store reconstruction error to calc mean later
  sAll(pattern)=s;          %store representation sparsity to calc mean later
  classExpected=class(pattern);

  %classification based on maximum response
  [val,ind]=max(y);
  classGenerated=classDict(ind);
  if classExpected~=classGenerated, errCountMax=errCountMax+1; end

  %classification based on k-maximum responses (vote among the k strongest nodes)
  [val,ind]=sort(y,'descend');
  for k=1:maxk
    counts=histc(classDict(ind(1:k)),knnEdges);
    [val,classGenerated]=max(counts); %ties resolve to the lowest class label
    if classExpected~=classGenerated, errCountkNN(k)=errCountkNN(k)+1; end
  end

  %classification based on sum of responses
  z=D*y;
  [val,classGenerated]=max(z);
  if classExpected~=classGenerated, errCountSum=errCountSum+1; end

  %classification based on reconstruction error
  switch GP.network
    case 'subnets'
      %classify based on which sub-dictionary has lowest reconstruction error
      [val,classGenerated]=min(nmse);
      if classExpected~=classGenerated, errCountErr=errCountErr+1; end
    case 'single'
      %calc reconstruction error using only the nodes of one class at a time
      err=zeros(1,numClasses);
      for c=1:numClasses
        inClass=(classDict==c);
        yclass=zeros(n,1);
        yclass(inClass)=y(inClass); %y values for one class only
        err(c)=norm(x-(V'*yclass),2);
      end
      %classify based on which class subset has lowest reconstruction error
      [val,classGenerated]=min(err);
      if classExpected~=classGenerated, errCountErr=errCountErr+1; end
  end

  if extraStats
    %ratio of the two largest summed responses; undefined with a single class row
    vals=sort(z,'descend');
    if numel(vals)>1
      ratio=vals(1)/vals(2);
    else
      ratio=NaN;
    end
    %classGenerated here is from the last predictor evaluated above
    if classExpected~=classGenerated
      errRatio=[errRatio,ratio];
      errSparsity=[errSparsity,s];
      errError=[errError,nmse];
    else
      repRatio=[repRatio,ratio];
      repSparsity=[repSparsity,s];
      repError=[repError,nmse];
    end
  end

  %plot results
  if show
    figure(pattern+show-1)
    plot_result(x,y,sTrace,V,classDict,classExpected);
  end
end
disp(' ');

if extraStats
  %report sparsity and error for correctly and incorrectly classified patterns
  disp([' errSparsity=',num2str(nanmean(errSparsity)),' repSparsity=',num2str(nanmean(repSparsity)),' / errError=',num2str(nanmean(errError)),' repError=',num2str(nanmean(repError)),' / errRatio=',num2str(nanmean(errRatio)),' repRatio=',num2str(nanmean(repRatio))]);
end

%calculate percentage classification errors
errkNN=100.*errCountkNN./numPatterns %no semicolon: displays the kNN error for each k
errMax=100*errCountMax/numPatterns;
errSum=100*errCountSum/numPatterns;
errErr=100*errCountErr/numPatterns;

%report results
disp(' errMax= errSum= errErr= execTime= Hoyer= NMSE=');
disp([num2str(errMax),' ',num2str(errSum),' ',num2str(errErr),' ',num2str(execTime),' ',num2str(nanmean(sAll)),' ',num2str(nanmean(eAll))]);
|
|
|
|
|
|
|
|
function plot_result(x,y,sTrace,W,class,classExpected)
%PLOT_RESULT plot input image, convergence trace, and most active basis vectors.
%   Clears the current figure and lays it out as a 2 x numToPlot grid:
%     top row    - the (up to 6) most active basis vectors W(i,:), titled
%                  "<class(i)-1>: <response>"
%     bottom row - the input x (titled classExpected-1), sTrace in red, and
%                  the full response vector y
%   x             - input pattern, drawn with plot_weights
%   y             - response vector; sorted along dim 1, so presumably a column (TODO confirm)
%   sTrace        - per the original comments, the change in sparsity over time; plotted as sTrace'
%   W             - dictionary; row i is drawn as the basis vector of node i
%   class         - class label of each dictionary node
%   classExpected - true class of x
%   Titles show labels minus 1 (presumably to display 0-based labels — verify).

clf

%plot change in sparsity over time, and final responses
%nodes responding above 10% of the maximum response count as "active"
numActive=nnz(y>0.1*max(y));
numToPlot=min(max(3,numActive),6);
maxsubplot(2,numToPlot,numToPlot+3:2*numToPlot,0.15), plot(y'), axis('tight')
maxsubplot(2,numToPlot,numToPlot+2,0.15); plot(sTrace','r','LineWidth',2);

%plot input image
maxsubplot(2,numToPlot,numToPlot+1,0.15),
plot_weights(x);
title(num2str(classExpected-1));

%plot most active basis vectors
[val,ind]=sort(y,1,'descend');
for i=1:min(numActive,6)
  maxsubplot(2,numToPlot,i,0.15)
  plot_weights(W(ind(i),:));
  title([num2str(class(ind(i))-1),': ',num2str(val(i))]);
end
drawnow;
|
|
|
|
|
|
|