You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

151 lines
4.6 KiB
Matlab

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

function [V,classDict]=define_dictionary_weights(data,class)
%DEFINE_DICTIONARY_WEIGHTS Build a dictionary of weight vectors from labelled training data.
%  [V,classDict]=define_dictionary_weights(data,class)
%  Inputs:
%    data  - m-by-numPatterns matrix; each column is one training pattern.
%    class - vector of positive integer class labels, one per column of data
%            (the number of classes is taken as max(class)).
%  Outputs:
%    V         - dictionary matrix; each ROW is one dictionary element of length m.
%    classDict - class label associated with each row of V.
%  The construction method is selected by the 'dictionary' field of the
%  structure returned by global_parameters:
%    'all'       - every training pattern becomes a dictionary element.
%    'learn'     - learn a fixed-size sub-dictionary per class (ILS-DLA/MOD).
%    'selective' - start from all patterns, then merge near-duplicate
%                  same-class elements into their mean.
%  Side effects: plots up to 10 dictionary elements per class in figure 1 and
%  prints progress/summary information to the command window.
[m,numPatterns]=size(data);
numClasses=max(class);
GP=global_parameters; %project-wide settings; GP.dictionary selects the method below
switch GP.dictionary
case 'all'
%DEFINE THE DICTIONARY
%use all data in the training set as elements of the dictionary
V=data';
%NOTE(review): W is assigned here but never used in this function - confirm
%whether define_pcbc_feedforward_weights has required side effects or this
%call can be removed
W=define_pcbc_feedforward_weights(V);
%RECORD THE CLASS OF EACH DICTIONARY ELEMENT
classDict=class;
case 'learn'
%Learn the dictionary using the ILS-DLA/MOD method Engan, Skretting and Husøy
%(2007) Digital Signal Processing 17 (2007) 3249. See also:
%http://www.ux.uis.no/~karlsk/dle/
%(no terminating semicolons on the next two lines: the values are echoed to
%the command window, presumably as a configuration printout - confirm intent)
nodesPerClass=100 %size of class specific sub-dictionaries (all equal for convenience)
epochs=25 %number of passes through the training data set
%RECORD THE CLASS OF EACH DICTIONARY ELEMENT
%elements 1..nodesPerClass get class 1, the next nodesPerClass get class 2, etc.
classDict=[];
for c=1:numClasses
classDict(1+(c-1)*nodesPerClass:c*nodesPerClass)=c;
end
%DEFINE THE DICTIONARY
%learn separate dictionary for each class
fprintf(1,'learning dictionary entries: ');
for c=1:numClasses
fprintf(1,'.%i.',c);
%extract data for one class to use as training data
ind=logical(class==c);
dataC=data(:,ind);
numPatternsC=size(dataC,2);
%define initial weights for one class: seed with the first nodesPerClass
%training patterns of the class
%NOTE(review): assumes every class has at least nodesPerClass patterns;
%this indexing fails otherwise - confirm against the data sets used
VC=dataC(:,1:nodesPerClass)';
VC=norm_dictionary(VC);
%learn weights for each class-specific sub-dictionary in turn
for k=1:ceil(epochs)
%calculate sparse representations using current dictionary
clear YC;
repeat=1;
while repeat
%sparse-code every pattern of this class with the current sub-dictionary;
%YC(:,l) is the coefficient vector for pattern l, nmse(l) its reconstruction error
for l=1:numPatternsC
[YC(:,l),e,s,nmse(l)]=calc_sparse_representation(VC,dataC(:,l),ones(1,nodesPerClass),0);
end
%find the dictionary element with the smallest peak response over all patterns
[val,ind]=min(max(abs(YC),[],2));
if val==0
%one of the dictionary elements fails to respond to any inputs: replace with training vector with the highest reconstruction error
disp('replacing dictionary element'), k %NOTE(review): k echoed without semicolon - presumably debug output
[val,new]=max(nmse);
VC(ind,:)=dataC(:,new)';
else
repeat=0;
end
end
%update dictionary: MOD least-squares step, VC' = dataC*YC'*(YC*YC')^-1
%NOTE(review): the commented mrdivide form below is the numerically
%preferable equivalent of inv() - consider switching
VC = ((dataC*YC')*inv(YC*YC'))';
%VC = ((dataC*YC')/(YC*YC'))';
VC(VC<0)=0; %constrain dictionary to be non-negative
VC=norm_dictionary(VC);
%display useful data about current dictionary
%(note: this reassigns nmse from a per-pattern vector to one overall scalar)
recon=VC'*YC;
nmse=(sum((dataC(:)-recon(:)).^2))./(1e-9+sum(dataC(:).^2));
sparsity=measure_sparsity_hoyer(YC);
disp([ ' mean NMSE=',num2str(nmse),...
' mean Sparsity=',num2str(sparsity)]);
end
%insert learned weights back into overall dictionary
%(V is grown implicitly on the first assignment)
ind=logical(classDict==c);
V(ind,:)=VC;
end
case 'selective'
%DEFINE THE DICTIONARY
%start with a dictionary containing all elements of the training set...
V=data';
%... then use the normalised cross-correlation to measure the similarity of
%dictionary elements...
Vnorm=norm_dictionary(V,2);
NCC=Vnorm*Vnorm'; %NCC(i,j) = normalised cross-correlation of elements i and j
indDelete=[]; %indices of elements scheduled for deletion
Vnew=[]; %replacement (merged) elements, one per merged group
classNew=[]; %class labels for the merged elements
[n,m]=size(V);
for j=1:n
if sum(ismember(indDelete,j))>0
%skip - already due for deletion
else
%...if multiple dictionary elements are similar and of the same class delete
%them and add a new element (with the same class label) that is the mean
%of the similar elements
%0.975 is the similarity threshold for merging
indSimilarAndSameClass=NCC(j,:)>0.975 & class==class(j);
if sum(indSimilarAndSameClass)>1
indDelete=[indDelete,find(indSimilarAndSameClass==1)];
Vnew=[Vnew;mean(V(indSimilarAndSameClass,:))];
classNew=[classNew,class(j)];
end
end
end
%reorganise dictionary, deleting similar elements and adding new (merged) elements
indKeep=ones(n,1);
indKeep(indDelete)=0;
indKeep=logical(indKeep);
classDict=class(indKeep);
classDict=[classDict,classNew]; %NOTE(review): assumes class is a row vector - confirm with callers
V=V(indKeep,:);
V=[V;Vnew];
otherwise
%NOTE(review): classDict (and V) are left unassigned on this path, so the
%plotting code and the function outputs will error - confirm this is acceptable
disp('ERROR: no method specified for creating dictionary');
end
%show elements of the dictionary
%one row of subplots per class, up to 10 elements shown per class
figure(1),clf
k=0;
for c=1:numClasses
if ~isempty(classDict)
inClass=find(classDict==c);
toplot=min(10,length(inClass));
else
%fallback when classDict is empty; nodesPerClass is only defined on the
%'learn' path, so this branch relies on that earlier state
toplot=min(10,nodesPerClass);
end
for j=1:toplot
maxsubplot(numClasses,toplot,(c-1)*toplot+j),
if ~isempty(classDict)
plot_weights(V(inClass(j),1:m));
else
k=k+1;plot_weights(V(k,1:m));
end
end
end
cmap=colormap('gray');cmap=1-cmap;colormap(cmap); %inverted greyscale colour map
set(gcf,'PaperPosition',[1 1 21 21]);
drawnow;
%print final dictionary dimensions
disp(' ');
[n,m]=size(V);
disp(['dictionary has ',num2str(n),' elements of length ',num2str(m)]);