function [classification_error,outputData,class,inTrain,inTest,numClasses]=classify_images(task,sigmaLGN,data,class,inTrain,inTest,imDims,numClasses,plotPolarity)
%Apply a two-stage hierarchical PC/BC-DIM network to image classification.
%Stage 1 weights are learnt by clustering the training images.
%Stage 2 weights are the summed stage-1 responses to all training members of each category.

if nargin<1 || isempty(task)
  task='USPS';% 'MNIST'; % 'YALE4';%
end

%set parameters
if nargin<2 || isempty(sigmaLGN)
  sigmaLGN=16 %scale (in pixels) of the ON/OFF image normalisation; left unterminated so the value is echoed
end
patchClustering='agglom';%'exemplars'; %'dim'; %
ONOFF=1; %if true, convert images and dictionary entries to ON/OFF channel pairs
figoff=0; %offset added to all figure numbers

%LOAD DATA
if nargin<3
  [data,class,inTrain,inTest,imDims,numClasses,plotPolarity]=load_classify_image_dataset(task);
end

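%load_classify_image_dataset is defined elsewhere. From its usage below,
%data is assumed to be numImages-by-numPixels (one image per row), class
%a vector of labels in 1:numClasses, and inTrain/inTest vectors of row
%indices selecting the training and testing images.
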
%CLUSTER TRAINING IMAGES TO FORM DICTIONARY
Xtrain=data(inTrain,:)'; %one training image per column
switch patchClustering
  case 'dim' %learn the dictionary with DIM itself
    n=5000; %number of dictionary entries
    skipTraining=0;
    filename=['classify_images_dim_weights',task,'_n',int2str(n),'.mat'];
    if skipTraining
      load(filename);
    else
      beta=0.1
      alpha=1
      batchLen=1 %length(inTrain) %number of training patterns used in each learning batch
      cycs=fix(40*max(n,length(inTrain))/sqrt(batchLen)) %number of training cycles
      show=fix(1e4/sqrt(batchLen)); %how often to report progress
      [W,V]=weight_initialisation_random(n,size(Xtrain,1));
      max_y=0;sum_nse=0;sum_sparsity=0;
      for cyc=1:cycs
        %choose a batch of input stimuli to train on
        order=randperm(length(inTrain));
        %calculate responses to the training batch
        [Y,E]=dim_activation(W,Xtrain(:,order(1:batchLen)),V);
        %update the weights using those responses and errors
        [W,V]=dim_learn(W,V,Y,E,beta,alpha);

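        %dim_activation and dim_learn are defined elsewhere in this
        %repository. dim_activation is assumed to implement the standard
        %PC/BC-DIM steady-state iteration (divisive input modulation); a
        %minimal sketch of that update, with hypothetical epsilon
        %constants and iteration count, would be:
        %  Y=zeros(size(W,1),size(X,2));
        %  for it=1:50
        %    E=X./(1e-2+V'*Y);  %error: input divided by reconstruction
        %    Y=(1e-6+Y).*(W*E); %responses scaled by bottom-up error
        %  end
        %(note that measure_nse below also uses V'*Y as the reconstruction)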
        %record data
        max_y=max([max_y,max(max(Y))]);
        sum_nse=sum_nse+measure_nse(Xtrain(:,order(1:batchLen)),V'*Y)./batchLen;
        sum_sparsity=sum_sparsity+mean(measure_sparsity_hoyer(Y));
        if rem(cyc,show)==0 || cyc==cycs
          fprintf(1,'.%i.',cyc);
          disp([' ymax=',num2str(max_y,3),...
                ' NSE=',num2str(sum_nse./show,3),...
                ' Sparsity=',num2str(sum_sparsity./show,3)]);
          max_y=0;sum_nse=0;sum_sparsity=0;
        end

      end
      save(filename, 'W', 'V');
    end

  case 'exemplars' %use every training image as a dictionary entry
    V=Xtrain';
    n=length(inTrain);

  case 'agglomAllPatches' %agglomerative clustering of all training images together
    numClusterMembersReqd=0
    similarityThres=0.85

    clusterIndex=clusterdata(Xtrain','cutoff',1-similarityThres,'criterion','distance','linkage','complete','distance',@distance_measure); %clustering is based on distance, not similarity, so the cutoff is 1-similarityThres

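    %distance_measure is defined elsewhere. For clusterdata it must have
    %the pdist custom-distance signature d=f(XI,XJ), where XI is a 1-by-p
    %row, XJ is m-by-p, and d is m-by-1. It is assumed to return 1 minus
    %a normalised similarity; a hypothetical cosine version would be:
    %  function d=distance_measure(XI,XJ)
    %  d=1-(XJ*XI')./max(1e-9,sqrt(sum(XJ.^2,2)).*norm(XI));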
    numClusters=max(clusterIndex);
    k=0;
    for i=1:numClusters
      clustInd=find(clusterIndex==i);
      if length(clustInd)>numClusterMembersReqd
        k=k+1;
        V(k,:)=mean(Xtrain(:,clustInd),2); %each dictionary entry is the mean of all cluster members
      end
    end
    n=k;

  case 'agglom' %agglomerative clustering applied separately to each class
    numClusterMembersReqd=0
    similarityThres=0.9

    k=0;
    for c=1:numClasses
      classInd=find(class(inTrain)==c);
      clusterIndex=clusterdata(Xtrain(:,classInd)','cutoff',1-similarityThres,'criterion','distance','linkage','complete','distance',@distance_measure); %clustering is based on distance, not similarity, so the cutoff is 1-similarityThres

      numClusters(c)=max(clusterIndex);
      for i=1:numClusters(c)
        clustInd=find(clusterIndex==i);
        if length(clustInd)>numClusterMembersReqd
          k=k+1;
          V(k,:)=mean(Xtrain(:,classInd(clustInd)),2); %each dictionary entry is the mean of all cluster members
        end
      end
    end
    numClusters %display the number of clusters found for each class
    n=k;
end

%plot a random sample of the dictionary
toPlot=randperm(n);
toPlot=toPlot(1:min(48,n));
if exist('W','var')
  figured(figoff+1),clf,
  plot_weights(W,toPlot,imDims,plotPolarity);
end
figured(figoff+2),clf,
plot_weights(V,toPlot,imDims,plotPolarity);
drawnow
print('-dpdf',[task,'_dictionary.pdf'])

n=size(V,1)

if ONOFF
  %CONVERT DICTIONARY ENTRIES TO ON/OFF WEIGHTS
  V=imnorm_batch(V',imDims,sigmaLGN)'; %each row becomes [ON,OFF] channels, doubling its length (see subfunction below)

  %preprocess input images in the same way (cached to disk)
  filename=['classify_images_imnormed_data_',task,'_sigmaLGN',num2str(sigmaLGN),'.mat'];
  if exist(filename,'file')==2
    load(filename);
  else
    data=imnorm_batch(data',imDims,sigmaLGN)';
    save(filename, 'data');
  end
  imDims=[imDims,2]; %record the extra ON/OFF dimension for plotting
end
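%(imnorm is defined elsewhere; given the parameter name sigmaLGN and its
%ON/OFF outputs, it is assumed to perform an LGN-like contrast
%normalisation, returning rectified positive (ON) and negative (OFF)
%response channels for each image.)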
%normalise weights
W=bsxfun(@rdivide,V,max(1e-6,sum(V,2)));    %feedforward: each row sums to 1 (the max guards against all-zero rows)
V=bsxfun(@rdivide,V,max(1e-6,max(V,[],2))); %feedback: each row has a maximum of 1
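%For example, a dictionary row [0.2 0.6 0.2] already sums to 1, so its W
%row is unchanged, while its V row becomes [1/3 1 1/3] after division by
%its maximum, 0.6.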
%rescale each input to range from 0 to 1
data=bsxfun(@minus,data,min(data,[],2));
data=bsxfun(@rdivide,data,max(1e-6,max(data,[],2))); %the max guards against constant (all-zero) images

figured(figoff+3),clf,
plot_weights(V,toPlot,imDims,plotPolarity);
drawnow
print('-dpdf',[task,'_stage1_weights.pdf'])
figured(figoff+4),clf,
plot_weights(data,toPlot,imDims,plotPolarity);
drawnow

%MATCH DICTIONARY PATCHES TO ALL IMAGES using DIM
outputData=dim_activation(W,data',V)';
%count the number of votes corresponding to each cluster
threshold=1e-3;
outputData(outputData<threshold)=0; %suppress negligible responses
for c=1:numClasses
  ind=class(inTrain)==c;
  Wvotes(c,:)=sum(outputData(inTrain(ind),:),1); %explicit dim argument so a single-image class still sums over images
end
Wvotes=bsxfun(@rdivide,Wvotes,max(1e-6,sum(Wvotes,1))); %would have no effect when using convolution
Vvotes=bsxfun(@rdivide,Wvotes,max(1e-6,max(Wvotes,[],2))); %standard DIM normalisation, as used in ism
%Vvotes=bsxfun(@rdivide,Wvotes,max(1e-6,max(Wvotes,[],1))); %slightly worse
figured(figoff+5),clf, imagesc(Wvotes); colorbar
figured(figoff+6),clf, imagesc(Vvotes); colorbar

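%At this point Wvotes/Vvotes are the numClasses-by-n weights of the second
%processing stage. As a sketch, a single new image x (a 1-by-numPixels row
%vector, preprocessed as above) would be classified by the trained
%two-stage network as:
%  y1=dim_activation(W,x',V);           %stage 1: dictionary responses
%  y2=dim_activation(Wvotes,y1,Vvotes); %stage 2: class responses
%  [~,predicted]=max(y2);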
%MATCH DICTIONARY PATCHES TO TESTING IMAGES using DIM
XclassifierTest=outputData(inTest,:)'; %stage-1 responses to the test images
Yclassifier=dim_activation(Wvotes,XclassifierTest,Vvotes);
threshold=1e-3;
Yclassifier(Yclassifier<threshold)=0;
%calculate classification error
[~,classPredicted]=max(Yclassifier);
classification_error=100*sum(classPredicted~=class(inTest))./length(inTest) %percentage of test images misclassified

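%For example, 47 errors on the 2007 USPS test images would give
%classification_error = 100*47/2007 = 2.34 (percent).
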
%plot example responses for a few test images
toPlot=randperm(length(inTest));
toPlot=toPlot(1:10);
fig=figoff+10;
for k=toPlot
  fig=fig+1;
  figured(fig), clf
  plot_network(data(inTest(k),:),outputData(inTest(k),:),Yclassifier(:,k),imDims,V,plotPolarity);
  print('-dpdf',[task,'_response_example',int2str(fig),'.pdf'])
end
%show which images from the original test dataset were misclassified
[data,class,inTrain,inTest,imDims]=load_classify_image_dataset(task); %reload the unpreprocessed images
figured(figoff+21),clf,plot_misclassified(classPredicted,class(inTest),data(inTest,:),imDims(1:2),plotPolarity,1);
print('-dpdf',[task,'_misclassified.pdf'])



function plot_network(input,Y1,Y2,imDims,Wvis,plotPolarity)
%Plot one test image together with its stage-1 responses (Y1) and
%stage-2 class responses (Y2).

subplot(3,5,13);
if length(imDims)==3 %ON/OFF input: show the ON minus OFF channel
  plot_image(diff(reshape(input,imDims),1,3)');
elseif length(imDims)==2
  plot_image(reshape(input,imDims)');
end

cmap=colormap('gray'); if plotPolarity>0, cmap=1-cmap; colormap(cmap); end

%bar chart of the stage-1 responses
axProp=subplot(3,1,2);
top=max(0.01,1.05.*max(Y1));
width=max(1,length(Y1)/2000);
bar([1:length(Y1)]-1,Y1,width,'FaceColor','r','EdgeColor','r','LineWidth',1);
axis([0.5,length(Y1)+0.5,0,top])

%label the most active stage-1 nodes with their dictionary entries
ax=axProp.Position;
[m,ind]=sort(Y1.*(1+0.001.*rand(size(Y1))),'descend'); %tiny jitter breaks ties between equal responses
numToLabel=min(25,length(find(m>0.25*m(1))));
for i=1:numToLabel
  axes('Position',[ax(1)+(ax(3).*(ind(i)-1)./length(Y1))-0.025,ax(2)+ax(4)*min(1,m(i)/top),0.06,0.06])
  if length(imDims)==3
    plot_image(diff(reshape(Wvis(ind(i),:),imDims),1,3)');
  elseif length(imDims)==2
    plot_image(reshape(Wvis(ind(i),:),imDims)');
  end
end

%bar chart of the stage-2 (class) responses
subplot(3,3,2);
top=max(0.01,1.05.*max(Y2));
width=max(1,length(Y2)/2000);
bar([1:length(Y2)]-1,Y2,width,'FaceColor','r','EdgeColor','r');
axis([-0.5,length(Y2)-0.5,0,top])

set(gcf,'PaperSize',[18 10],'PaperPosition',[0 0 18 10],'PaperOrientation','Portrait');
drawnow;



function [X]=imnorm_batch(I,imDims,sigma,gain,leavepadded)
%Apply imnorm to each column of I (one image per column) and return the
%concatenated ON and OFF channel responses, so X has twice as many rows as I.
if nargin<3, sigma=[]; end
if nargin<4, gain=[]; end
if nargin<5, leavepadded=[]; end

[a,batchLen]=size(I);
I=single(I);

X=zeros(2*a,batchLen,'single');
for t=1:batchLen
  It=reshape(I(:,t),imDims);

  [~,~,Xon,Xoff]=imnorm(It,sigma,gain,leavepadded);

  X(:,t)=[Xon(:);Xoff(:)];
end