You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
4.1 KiB
Matlab
127 lines
4.1 KiB
Matlab
function classify_images(task)
%Apply PC/BC-DIM to classifying images.
%
%  CLASSIFY_IMAGES(TASK) loads the dataset named TASK (default 'USPS') via
%  IMAGES_LOAD_DATASET. It builds a network whose weight rows are the training
%  images (one row per exemplar), each extended with weights onto one class
%  input per class. It plots the network response (from DIM_ACTIVATION_BATCH)
%  to three hand-picked inputs, saving each figure as classify_<TASK><k>.pdf.
%  Finally it displays the percentage of test images whose class input has
%  the largest reconstruction but is not the true class.
if nargin<1 || isempty(task)
  task='USPS';%'YALE'; %
end
onoff=0; %1: encode images as separate ON and OFF channels; 0: scale each image to [0,1]

%load data (each row of data is one image; labels holds class indices 1..numClasses)
[data,labels,inTrain,inTest,imDims,numClasses]=images_load_dataset(task);

if onoff
  %subtract mean from each image
  data=bsxfun(@minus,data,mean(data,2));
  %split into ON and OFF channels
  on=data; on(on<0)=0;
  off=-data; off(off<0)=0;
  data=[on,off];
else
  %scale each image to range between 0 and 1
  data=bsxfun(@minus,data,min(data,[],2));
  rowMax=max(data,[],2);
  %a constant image has zero range: leave it as all zeros rather than 0/0=NaN
  rowMax(rowMax==0)=1;
  data=bsxfun(@rdivide,data,rowMax);
end
%normalise each image to unit L1 norm (all-zero rows are left as zeros)
data=bsxfun(@rdivide,data,max(1e-6,sum(abs(data),2)));

%define weights: use each training image as an exemplar
W=data(inTrain,:);
%additional inputs represent class: D(j,c) is 1 when exemplar j belongs to class c
D=zeros(length(inTrain),numClasses);
for c=1:numClasses
  %make all weights equal to 1 from basis vectors representing the same class
  D(:,c)=labels(inTrain)==c;
end
%class weights are scaled to the largest absolute image weight
W=[W,max(max(abs(W))).*D];

%normalise weights: divide each row by its largest absolute value
%W=bsxfun(@rdivide,W,max(1e-6,sum(abs(W),2)));
W=bsxfun(@rdivide,W,max(1e-6,max(abs(W),[],2)));
numInputs=size(W,2); %image inputs followed by numClasses class inputs

%define test cases (hand-picked test images; requires at least 140 test images)
X=zeros(numInputs,3);
X(:,1)=[data(inTest(130),:)';zeros(numClasses,1)]; %certain sparse
X(:,2)=[data(inTest(140),:)';zeros(numClasses,1)]; %uncertain
%blank image with only class input numClasses-6 active (requires numClasses>=7)
X(numInputs-6,3)=1;

%present exemplar test cases to network and record results
for k=1:size(X,2)
  x=X(:,k);
  [y,e,r]=dim_activation_batch(W,x);
  figure(k),clf
  images_plot_results(x,r,y,imDims,W,onoff);
  print(gcf, '-dpdf', ['classify_',task,int2str(k),'.pdf']);
end

%present all test cases to network and calculate classification error
X=[data(inTest,:)';zeros(numClasses,length(inTest))];
[Y,E,R]=dim_activation_batch(W,X);
%predicted class = class input with the largest reconstruction, one per test image (column)
[maxR,classPredicted]=max(R(numInputs-numClasses+1:numInputs,:),[],1);
%compare as column vectors so the result does not depend on whether labels is
%stored as a row or a column (mixed orientations would expand to a matrix)
trueClass=labels(inTest);
percent_error=100*sum(classPredicted(:)~=trueClass(:))./length(inTest)
function images_plot_results(x,r,y,imSize,W,onoff)
%Plot one network state in the current figure. There are three rows:
%  bottom: input image part (x_a) and class part (x_b)
%  middle: response of each weight row (y), with thumbnails of the weights of
%          the most strongly responding rows drawn above their responses
%  top:    reconstructed image part (r_a) and class part (r_b)
%
%  x, r   - input and reconstruction vectors. The first prod(imSize) elements
%           are image pixels, followed by a second block of prod(imSize) OFF
%           pixels when onoff is true. All remaining elements are class values.
%  y      - one response per row of W
%  imSize - dimensions passed to RESHAPE; the reshaped image is transposed for display
%  W      - weights, one row per element of y, laid out like x
%  onoff  - true if the image part is stored as separate ON and OFF channels
nPix=prod(imSize);
pix=1:nPix;
if onoff
  %recombine ON and OFF channels into a single signed image
  xIm=x(pix)-x(pix+nPix);
  rIm=r(pix)-r(pix+nPix);
  clsStart=2*nPix+1;
else
  xIm=x(pix);
  rIm=r(pix);
  clsStart=nPix+1;
end
xCl=x(clsStart:end);
rCl=r(clsStart:end);
green=[0,0.7,0];

%bottom row: input image and class inputs
plot_image_panel([0.12,0.05,0.37,0.24],xIm,imSize,'x_a','k');
plot_class_panel([0.51,0.05,0.37,0.24],xCl,1.05,'x_b','k','k');

%middle row: responses y
top=double(max(0.125,max(y)));
offset=0.085;
axes('Position',[0.12+offset,0.38,0.76-offset,0.24]),
plot(y,'r'),axis([0.5,length(y)+0.5,0,top])
xt=get(gca,'XTick');
%drop the last tick if it sits too close to the right-hand edge
if max(xt)>0.95.*length(y), xt=xt(1:end-1); end
set(gca,'XTick',xt,'YTick',0:0.1:1,'FontSize',18);
hold on
text(0.01,1,'y','Units','normalized','color','r','FontSize',18,'FontWeight','bold','VerticalAlignment','top');
%rank responses in descending order; the tiny random jitter breaks ties
[ySorted,ind]=sort(y.*(1+0.001.*rand(size(y))),'descend');
%show at most 25 thumbnails, only for responses above a quarter of the largest
numToLabel=min(25,sum(ySorted>0.25*ySorted(1)));
for k=1:numToLabel
  %thumbnail of this row's image weights, placed above its response in the y plot
  axes('Position',[offset+0.12+(0.76-offset).*(ind(k)-1)./length(y)-0.025,0.385+0.24*min(1,ySorted(k)/top),0.05,0.05])
  if onoff
    imagesc(reshape(W(ind(k),pix)-W(ind(k),pix+nPix),imSize)');
  else
    imagesc(reshape(W(ind(k),pix),imSize)');
  end
  axis('equal','tight','off')
end

%top row: reconstructed image and class reconstruction
plot_image_panel([0.12,0.71,0.37,0.24],rIm,imSize,'r_a',green);
plot_class_panel([0.51,0.71,0.37,0.24],rCl,1.05,'r_b',green,'FaceColor',green);

colormap('gray');%cmap=1-cmap;colormap(cmap);

set(gcf,'PaperSize',[18 16],'PaperPosition',[0 0.5 18 15],'PaperOrientation','Portrait');



function plot_image_panel(pos,im,imSize,label,labelColor)
%Create axes at POS and draw vector IM (reshaped to imSize, then transposed) with a bold corner label.
axes('Position',pos),
imagesc(reshape(im,imSize)'),axis('equal','tight','on')
set(gca,'YTick',[],'XTick',[],'FontSize',18);
text(0.04,1,label,'Units','normalized','color',labelColor,'FontSize',18,'FontWeight','bold','VerticalAlignment','top')



function plot_class_panel(pos,vals,top,label,labelColor,varargin)
%Create axes at POS and draw VALS as a bar chart (extra arguments go to BAR) with a bold corner label.
axes('Position',pos),
bar(vals,varargin{:}),axis([0.5,length(vals)+0.5,0,top])
set(gca,'FontSize',18);
%fix the tick positions as well as the labels: labelling automatic ticks could misalign them
if length(vals)==10, set(gca,'XTick',1:10,'XTickLabel',0:9); end
text(0.02,1,label,'Units','normalized','color',labelColor,'FontSize',18,'FontWeight','bold','VerticalAlignment','top')