You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

127 lines
4.1 KiB
Matlab

function classify_images(task,onoff)
%CLASSIFY_IMAGES Apply PC/BC-DIM to classifying images.
%   CLASSIFY_IMAGES(TASK) loads the dataset TASK with IMAGES_LOAD_DATASET
%   (default 'USPS'; 'YALE' is the alternative noted below).
%   It builds a network whose basis vectors are the training images. Each
%   vector is extended by numClasses class inputs that are non-zero only
%   for that image's class. It then:
%     - presents three example inputs, plots each response with
%       IMAGES_PLOT_RESULTS in figure k, and saves it as
%       classify_<TASK><k>.pdf in the current folder;
%     - presents every test image and displays (as percent_error) the
%       percentage of test images whose most active class reconstruction
%       differs from their label.
%   CLASSIFY_IMAGES(TASK,ONOFF) selects the preprocessing:
%     - ONOFF false (the default, matching the previously hard-coded
%       value) rescales each image to [0,1];
%     - ONOFF true subtracts each image's mean and splits it into separate
%       ON (positive part) and OFF (negative part) channels.
if nargin<1 || isempty(task)
  task='USPS';%'YALE';
end
if nargin<2 || isempty(onoff)
  onoff=0;
end
%load data; labels are held in "labels" (not "class") so that MATLAB's
%built-in CLASS function is not shadowed
[data,labels,inTrain,inTest,imDims,numClasses]=images_load_dataset(task);
if onoff
  %subtract mean from each image (one image per row)
  data=bsxfun(@minus,data,mean(data,2));
  %split into ON and OFF channels
  on=data; on(on<0)=0;
  off=-data; off(off<0)=0;
  data=[on,off];
else
  %scale each image to range between 0 and 1
  data=bsxfun(@minus,data,min(data,[],2));
  imRange=max(data,[],2);
  %a constant image has zero range: divide it by 1 so it becomes all zeros
  %instead of 0/0 = NaN
  imRange(imRange==0)=1;
  data=bsxfun(@rdivide,data,imRange);
end
%normalise each image so its absolute values sum to 1 (the 1e-6 floor
%avoids dividing by zero for all-zero images)
data=bsxfun(@rdivide,data,max(1e-6,sum(abs(data),2)));
%define weights: use each training image as an exemplar (one row of W each)
W=data(inTrain,:);
%additional inputs represent class: D(i,c) is 1 when training exemplar i
%belongs to class c, so all basis vectors of the same class share the
%same class weight
D=double(bsxfun(@eq,reshape(labels(inTrain),[],1),1:numClasses));
W=[W,max(abs(W(:))).*D];
%normalise weights so the largest absolute value in each row is 1
%(alternative: W=bsxfun(@rdivide,W,max(1e-6,sum(abs(W),2)));)
W=bsxfun(@rdivide,W,max(1e-6,max(abs(W),[],2)));
%display the number of basis vectors (n) and of inputs (m)
[n,m]=size(W)
%define test cases, one per column: image inputs followed by numClasses
%class inputs, which are zero so the class is not given to the network
X=zeros(m,3,class(data));
X(:,1)=[data(inTest(130),:)';zeros(numClasses,1)]; %certain sparse
X(:,2)=[data(inTest(140),:)';zeros(numClasses,1)]; %uncertain
%blank image with only one class input set: row m-6 is class input
%numClasses-6 (NOTE(review): this lands on an image input if numClasses<=6)
X(m-6,3)=1;%blank
%present exemplar test cases to network, then plot and save the results
for k=1:size(X,2)
  x=X(:,k);
  [y,e,r]=dim_activation_batch(W,x);
  figure(k),clf
  images_plot_results(x,r,y,imDims,W,onoff);
  print(gcf, '-dpdf', ['classify_',task,int2str(k),'.pdf']);
end
%present all test cases to network and calculate classification error
X=[data(inTest,:)';zeros(numClasses,length(inTest))];
[Y,E,R]=dim_activation_batch(W,X);
%predicted class = index of the most active class row of R, per test image
%(explicit dim 1 so a single class row is still reduced per column)
[maxR,classPredicted]=max(R(m-numClasses+1:m,:),[],1);
%compare as column vectors: mixing a row of predictions with a column of
%labels would otherwise broadcast to a matrix (or error on older MATLAB)
isWrong=reshape(classPredicted,[],1)~=reshape(labels(inTest),[],1);
percent_error=100*sum(isWrong)./length(inTest)
function images_plot_results(x,r,y,imSize,W,onoff)
%IMAGES_PLOT_RESULTS Plot network input, output and node responses in the
%current figure.
%   The function creates five sets of axes (positions in normalized
%   figure units):
%     bottom row : x_a = image part of X (left), and
%                  x_b = bar chart of the remaining (class) part of X (right)
%     middle row : y = red line plot of Y, with a thumbnail of the image part
%                  of row ind(i) of W drawn above each of the most active
%                  elements of Y
%     top row    : r_a = image part of R (left), and
%                  r_b = bar chart of the class part of R (right)
%   x      - input vector: prod(imSize) image values (2*prod(imSize) when
%            ONOFF is true: ON channel, then OFF channel), followed by the
%            class inputs
%   r      - vector laid out like x (presumably the network's
%            reconstruction of x; confirm against dim_activation_batch)
%   y      - node activations; element i is associated with row i of W
%   imSize - image dimensions passed to RESHAPE; each reshaped image is
%            transposed before display
%   W      - weight matrix, one row per element of y, columns laid out like x
%   onoff  - if true, images are displayed as ON channel minus OFF channel
%   The function also sets a gray colormap and the paper size/position
%   used when printing.
dataIndeces=[1:prod(imSize)];
if onoff
%ON/OFF layout: image = first block (ON) minus second block (OFF); the
%class inputs follow both blocks
xIm=x(dataIndeces)-x(dataIndeces+max(dataIndeces));
rIm=r(dataIndeces)-r(dataIndeces+max(dataIndeces));
xCl=x(1+2*max(dataIndeces):end);
rCl=r(1+2*max(dataIndeces):end);
else
%single-channel layout: image block, then the class inputs
xIm=x(dataIndeces);
rIm=r(dataIndeces);
xCl=x(1+max(dataIndeces):end);
rCl=r(1+max(dataIndeces):end);
end
%upper y-limit of the class bar charts
top=1.05;
%bottom-left: input image
axes('Position',[0.12,0.05,0.37,0.24]),
imagesc(reshape(xIm,imSize)'),axis('equal','tight','on')
set(gca,'YTick',[],'XTick',[],'FontSize',18);
text(0.04,1,'x_a','Units','normalized','color','k','FontSize',18,'FontWeight','bold','VerticalAlignment','top')
%bottom-right: class part of the input
axes('Position',[0.51,0.05,0.37,0.24]),
bar(xCl,'k'),axis([0.5,length(xCl)+0.5,0,top])
set(gca,'FontSize',18);
%with exactly 10 class inputs, label the bars 0-9 (presumably digit classes)
%NOTE(review): this assumes bar() placed the ticks at 1..10; setting
%'XTick',1:10 explicitly would guarantee the alignment
if length(xCl)==10, set(gca,'XTickLabel',[0:9]); end
text(0.02,1,'x_b','Units','normalized','color','k','FontSize',18,'FontWeight','bold','VerticalAlignment','top')
%y-axis limit for the activation plot: the largest activation, but never
%below 0.125
top=double(max(0.125,max(y)));
%the middle panel starts OFFSET further right than the other panels but
%keeps the same right edge (0.12+0.76 = 0.51+0.37 = 0.88)
offset=0.085;
axes('Position',[0.12+offset,0.38,0.76-offset,0.24]),
plot(y,'r'),axis([0.5,length(y)+0.5,0,top])
%drop the last x tick if it lies beyond 95% of the number of nodes
%(presumably to keep the right edge uncluttered)
xt=get(gca,'XTick');
if max(xt)>0.95.*length(y), xt=xt(1:end-1); end
set(gca,'XTick',xt,'YTick',[0:0.1:1],'FontSize',18);
hold on
text(0.01,1,'y','Units','normalized','color','r','FontSize',18,'FontWeight','bold','VerticalAlignment','top');
%sort activations in descending order; the tiny multiplicative random jitter
%(under 0.1%) breaks ties in a random order, so the choice among equally
%active nodes can differ between calls. m holds the jittered sorted values,
%ind the corresponding indices into y (and rows of W)
[m,ind]=sort(y.*(1+0.001.*rand(size(y))),'descend');
%draw thumbnails for at most 25 nodes whose activation exceeds 25% of the
%largest one
numToLabel=min(25,length(find(m>0.25*m(1))));
for i=1:numToLabel,
%small axes centred horizontally on node ind(i) within the middle panel,
%placed vertically at the node's activation (capped at the panel top)
axes('Position',[offset+0.12+(0.76-offset).*(ind(i)-1)./length(y)-0.025,0.385+0.24*min(1,m(i)/top),0.05,0.05])
if onoff
imagesc(reshape(W(ind(i),dataIndeces)-W(ind(i),dataIndeces+max(dataIndeces)),imSize)'),
else
imagesc(reshape(W(ind(i),dataIndeces),imSize)');
end
axis('equal','tight','off')
end
%reset the y-limit for the r_b bar chart (top was changed for the y panel)
top=1.05;
%top-left: image part of r
axes('Position',[0.12,0.71,0.37,0.24]),
imagesc(reshape(rIm,imSize)'),axis('equal','tight','on')
set(gca,'YTick',[],'XTick',[],'FontSize',18);
text(0.04,1,'r_a','Units','normalized','color',[0,0.7,0],'FontSize',18,'FontWeight','bold','VerticalAlignment','top')
%top-right: class part of r
axes('Position',[0.51,0.71,0.37,0.24]),
bar(rCl,'FaceColor',[0,0.7,0]),axis([0.5,length(rCl)+0.5,0,top])
set(gca,'FontSize',18);
%tests length(xCl); the labels are meant for rCl, which has the same length
%when r is laid out like x
if length(xCl)==10, set(gca,'XTickLabel',[0:9]); end
text(0.02,1,'r_b','Units','normalized','color',[0,0.7,0],'FontSize',18,'FontWeight','bold','VerticalAlignment','top')
%gray colormap for the images; cmap is only kept for the commented-out
%inversion
cmap=colormap('gray');%cmap=1-cmap;colormap(cmap);
%paper size/position used when the figure is printed (in the figure's
%current PaperUnits)
set(gcf,'PaperSize',[18 16],'PaperPosition',[0 0.5 18 15],'PaperOrientation','Portrait');