|
|
function [V,classDict]=define_dictionary_weights(data,class)
%DEFINE_DICTIONARY_WEIGHTS create a dictionary, with a class label for every element.
%
%   [V,classDict]=define_dictionary_weights(data,class)
%
%   data      - training set, one pattern per column (m rows)
%   class     - class label of each column of data; labels are assumed to be
%               integers in 1..max(class)
%   V         - dictionary, one element (of length m) per row
%   classDict - class label of each row of V
%
%   The construction method is selected by the field 'dictionary' of the
%   structure returned by global_parameters:
%     'all'       - every training pattern becomes a dictionary element
%     'learn'     - a sub-dictionary of nodesPerClass elements is learnt for
%                   each class with the ILS-DLA/MOD method
%     'selective' - start from all training patterns, then replace each group
%                   of highly correlated patterns of the same class by its mean
%   Any other value raises an error. Up to 10 elements of each class are then
%   plotted in figure 1 and the size of the dictionary is printed.

m=size(data,1);
numClasses=max(class);
GP=global_parameters;

switch GP.dictionary
    case 'all'
        %use all data in the training set as elements of the dictionary
        V=data';
        %W is not used here; the call is kept in case define_pcbc_feedforward_weights
        %has side effects. NOTE(review): confirm, and remove the call if it has none.
        W=define_pcbc_feedforward_weights(V); %#ok<NASGU>
        %record the class of each dictionary element
        classDict=class;

    case 'learn'
        nodesPerClass=100 %size of class specific sub-dictionaries (all equal for convenience); no semicolon, so the value is echoed
        epochs=25 %number of passes through the training data set; no semicolon, so the value is echoed
        [V,classDict]=learn_class_dictionaries(data,class,numClasses,nodesPerClass,epochs);

    case 'selective'
        %merge same-class elements whose normalised cross-correlation exceeds 0.975
        [V,classDict]=merge_similar_elements(data,class,0.975);

    otherwise
        %stop here with a clear message, rather than failing later on undefined outputs
        error('ERROR: no method specified for creating dictionary');
end

plot_dictionary_elements(V,classDict,numClasses,m);

disp(' ');
[n,m]=size(V);
disp(['dictionary has ',num2str(n),' elements of length ',num2str(m)]);


function [V,classDict]=learn_class_dictionaries(data,class,numClasses,nodesPerClass,epochs)
%LEARN_CLASS_DICTIONARIES learn a separate sub-dictionary for each class.
%Uses the ILS-DLA/MOD method of Engan, Skretting and Husøy (2007) Digital Signal
%Processing 17, 32-49. See also: http://www.ux.uis.no/~karlsk/dle/
%Each sub-dictionary has nodesPerClass elements. It is initialised from the first
%nodesPerClass patterns of that class and refined over ceil(epochs) passes through
%that class's training data. The rows of V for class c are those with classDict==c.

m=size(data,1);

%record the class of each dictionary element: nodesPerClass copies of each label, in class order
classDict=reshape(repmat(1:numClasses,nodesPerClass,1),1,[]);
V=zeros(length(classDict),m);

fprintf(1,'learning dictionary entries: ');
for c=1:numClasses
    fprintf(1,'.%i.',c);

    %extract data for one class to use as training data
    dataC=data(:,class==c);
    numPatternsC=size(dataC,2);
    if numPatternsC<nodesPerClass
        error('class %i has %i training patterns but %i are needed to initialise its sub-dictionary',...
            c,numPatternsC,nodesPerClass);
    end

    %initial sub-dictionary: the first nodesPerClass patterns of this class
    VC=norm_dictionary(dataC(:,1:nodesPerClass)');

    for k=1:ceil(epochs)
        %calculate sparse representations using current dictionary;
        %nmsePattern(l) is the reconstruction error reported for pattern l
        YC=zeros(nodesPerClass,numPatternsC);
        nmsePattern=zeros(1,numPatternsC);
        repeat=1;
        while repeat
            for l=1:numPatternsC
                [YC(:,l),e,s,nmsePattern(l)]=calc_sparse_representation(VC,dataC(:,l),ones(1,nodesPerClass),0); %#ok<ASGLU>
            end
            %the dictionary element whose largest response over all patterns is smallest
            [val,dead]=min(max(abs(YC),[],2));
            if val==0
                %one of the dictionary elements fails to respond to any inputs: replace it
                %with the training vector that has the highest reconstruction error.
                %NOTE(review): the replacement is not re-normalised, and this loop has no
                %iteration limit if an element stays silent - confirm both are intended.
                disp('replacing dictionary element'), k
                [val,worst]=max(nmsePattern);
                VC(dead,:)=dataC(:,worst)';
            else
                repeat=0;
            end
        end

        %update dictionary (MOD step: least-squares fit of the dictionary to dataC given YC).
        %mrdivide solves the system directly, which is more accurate than forming inv(YC*YC')
        VC=((dataC*YC')/(YC*YC'))';
        VC(VC<0)=0; %constrain dictionary to be non-negative
        VC=norm_dictionary(VC);

        %display useful data about current dictionary. The NMSE is computed over all
        %patterns of this class, using the updated dictionary with the codes YC
        %calculated before the update.
        recon=VC'*YC;
        nmseTotal=sum((dataC(:)-recon(:)).^2)./(1e-9+sum(dataC(:).^2));
        sparsity=measure_sparsity_hoyer(YC);
        disp([ ' mean NMSE=',num2str(nmseTotal),' mean Sparsity=',num2str(sparsity)]);
    end

    %insert learned weights back into overall dictionary
    V(classDict==c,:)=VC;
end


function [V,classDict]=merge_similar_elements(data,class,threshold)
%MERGE_SIMILAR_ELEMENTS build a dictionary from all training patterns, merging similar ones.
%Starts with one dictionary element (row of V) per training pattern. It then scans the
%elements in order. Each element not yet scheduled for deletion is grouped with every
%element of the same class whose normalised cross-correlation with it exceeds threshold.
%If the group has more than one member, all members are deleted and their mean is
%appended as a new element with the same class label. classDict is a row vector.
%NOTE(review): members already merged into an earlier group can be included (and
%averaged) again in a later group - confirm this is intended.

V=data';
cls=class(:)'; %row vector, so it lines up with the rows of NCC whatever the orientation of class

%normalised cross-correlation between dictionary elements
Vnorm=norm_dictionary(V,2);
NCC=Vnorm*Vnorm';

n=size(V,1);
toDelete=false(n,1);
Vnew=[];
classNew=[];
for j=1:n
    if ~toDelete(j)
        group=NCC(j,:)>threshold & cls==cls(j);
        if sum(group)>1
            toDelete(group)=true;
            Vnew=[Vnew;mean(V(group,:))]; %#ok<AGROW>
            classNew=[classNew,cls(j)]; %#ok<AGROW>
        end
    end
end

%reorganise dictionary, deleting similar elements and adding new (merged) elements
classDict=[cls(~toDelete),classNew];
V=[V(~toDelete,:);Vnew];


function plot_dictionary_elements(V,classDict,numClasses,m)
%PLOT_DICTIONARY_ELEMENTS show up to 10 elements of each class in figure 1, one row per class.

maxPerClass=10;
if isempty(numClasses), numClasses=0; end %max() of an empty label vector is []

figure(1),clf
inClass=cell(1,numClasses);
toplot=zeros(1,numClasses);
for c=1:numClasses
    inClass{c}=find(classDict==c);
    toplot(c)=min(maxPerClass,length(inClass{c}));
end
%use the same number of columns for every row, so that a class with fewer elements
%does not get a different grid whose subplots overlap the rows of other classes
numCols=max([toplot,1]);

for c=1:numClasses
    for j=1:toplot(c)
        maxsubplot(numClasses,numCols,(c-1)*numCols+j),
        plot_weights(V(inClass{c}(j),1:m));
    end
end
cmap=colormap('gray');cmap=1-cmap;colormap(cmap);
set(gcf,'PaperPosition',[1 1 21 21]);
drawnow;
|
|
|