You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

407 lines
15 KiB
Matlab

function train_ism(task)
% TRAIN_ISM  Train an implicit shape model (ISM) and store its codebook.
%   train_ism(task) runs the whole training pipeline on the training image
%   set named by TASK and writes the resulting codebook to the file
%   'ISM_codebook_<task>.mat'. TASK defaults to 'cars' when it is omitted
%   or empty.
if nargin<1 || isempty(task)
    task='cars';
end

% Step 1: read the training images together with their class labels,
% object-centre coordinates and object-region masks.
[TrainingImages,imageClassList,objLocation,objRegion]=load_dataset(task);

% Step 2: cut image patches from around the keypoints of every training
% image. The 'carsScale' task uses a larger patch half-width than the rest.
if strcmp(task,'carsScale')
    patchHalfLen=8;
else
    patchHalfLen=7;
end
[patches,patchClassList]=extract_patches(TrainingImages,imageClassList,objRegion,patchHalfLen);
% Optional: patchClassList(patchClassList>1)=1; would make the patches non class-specific.

% Step 3: cluster the patches; the surviving cluster centres become the
% codebook entries (cluster_patches keeps clusters with more members than
% the count given here).
minMembersPerCluster=12;
similarityThres=0.4;
[patches,patchClassList]=cluster_patches(patches,patchClassList,minMembersPerCluster,similarityThres);

% Step 4: for every codebook entry, record where the object centre lies
% relative to the patch position.
Locations=extract_locations(TrainingImages,imageClassList,objLocation,patchHalfLen,patches,patchClassList,similarityThres);

% Step 5: write the codebook to disk. The assignment is deliberately left
% without a semicolon so that the output file name is echoed to the console.
filename=['ISM_codebook_',task,'.mat']
save(filename,'-v7.3','patches','patchClassList','Locations','similarityThres')

% Step 6: display the codebook entries (patches and vote positions).
% NOTE(review): figured is not defined in this file - presumably a project
% helper that selects/creates a figure window; confirm.
figured(1)
clf
plot_codebook(patches,patchClassList,Locations);
function [TrainingImages,classList,objLocation,objRegion]=load_dataset(task)
% LOAD_DATASET  Assemble the training images for the named data set.
%   Training images should contain one object only.
%
%   Input:
%     task           - name of the data set: 'cars', 'carsScale', 'horses',
%                      'USPS' or 'NORB'. Any other value raises an error.
%   Outputs:
%     TrainingImages - cell array of training images.
%     classList      - class label of each image (0 is used for background
%                      images).
%     objLocation    - one row per image holding the coordinates of the
%                      object centre within that image.
%     objRegion      - cell array of masks marking the object region of
%                      each image.
%
%   For 'carsScale', 'horses' and 'USPS' the outputs are built by the
%   helpers expand_car_image_scales / expand_affine_image_set, which are
%   defined outside this file.
switch task
    case 'cars'
        numBackgnd=500;
        numCars=550;
        numImages=numBackgnd+numCars;
        % allocate everything up front; background entries keep class 0
        % and an all-zero objLocation row
        TrainingImages=cell(1,numImages);
        objRegion=cell(1,numImages);
        classList=zeros(1,numImages);
        objLocation=zeros(numImages,2);
        %background images
        for i=1:numBackgnd
            TrainingImages{i}=load_image_car(i-1,2);
            objRegion{i}=ones(size(TrainingImages{i}),'single');
        end
        %car images
        for i=1:numCars
            idx=numBackgnd+i;
            TrainingImages{idx}=load_image_car(i-1,1);
            classList(idx)=1;
            %assumes all objects are centred in the training images.
            objLocation(idx,1:2)=size(TrainingImages{idx})./2;
            objRegion{idx}=ones(size(TrainingImages{idx}),'single');
        end

    case 'carsScale'
        TrainingImages=[];classList=[];objLocation=[];objRegion=[];
        %background images
        numBackgnd=500;
        for i=1:numBackgnd
            I=load_image_car(i-1,2);
            [TrainingImages,classList,objLocation,objRegion]=expand_car_image_scales(0,i,TrainingImages,classList,[],I,0,objRegion);
        end
        %car images
        for i=1:550
            I=load_image_car(i-1,1);
            [TrainingImages,classList,objLocation,objRegion]=expand_car_image_scales(numBackgnd,i,TrainingImages,classList,objLocation,I,1,objRegion);
        end

    case 'horses'
        TrainingImages=[];classList=[];objLocation=[];objRegion=[];
        %background images
        %(alternatives previously tried: I=load_image_car(i-1,2), and
        % expand_horse_image_set in place of expand_affine_image_set)
        numBackgnd=100;
        for i=1:numBackgnd
            I=load_image_101(i);
            [TrainingImages,classList,objLocation,objRegion]=expand_affine_image_set(0,i,TrainingImages,classList,objLocation,objRegion,I,[],0);
        end
        %horse images (use first 100 for training, next 228 will be for testing)
        for i=1:100
            [I,groundtruth]=load_image_horse(i);
            [TrainingImages,classList,objLocation,objRegion]=expand_affine_image_set(numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,groundtruth,1);
        end

    case 'USPS'
        [images,classes,trainingIndeces,testingIndeces,numClasses]=load_data_usps;
        numBackgnd=0;
        TrainingImages=[];classList=[];objLocation=[];objRegion=[];
        for i=trainingIndeces
            I=preprocess_usps_image(images(i,:));
            %the pre-processed image itself is passed as the object mask
            [TrainingImages,classList,objLocation,objRegion]=expand_affine_image_set(numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,I,classes(i));
        end

    case 'NORB'
        [images,classes,numClasses]=load_data_norb(1);
        numBackgnd=0;
        for i=1:length(classes)
            %each row of images is unpacked into a 96x96 image
            TrainingImages{numBackgnd+i}=reshape(images(i,:),96,96)';
            classList(numBackgnd+i)=classes(i);
            %assumes all objects are centred in the training images.
            objLocation(numBackgnd+i,1:2)=size(TrainingImages{numBackgnd+i})./2;
            objRegion{numBackgnd+i}=ones(size(TrainingImages{numBackgnd+i}),'single');
        end

    otherwise
        %fail immediately with a clear message; merely displaying a message
        %and carrying on would crash below on the undefined outputs
        error('train_ism:unknownDataset','ERROR: unknown data set: %s',task);
end
disp(['loaded ',int2str(length(TrainingImages)),' training images']);
function [patches,patchClassList]=extract_patches(I,imageClassList,objRegion,hlen)
% EXTRACT_PATCHES  Collect square image patches centred on keypoints.
%   I              - cell array of images.
%   imageClassList - class label of each image in I.
%   objRegion      - cell array of masks, one per image; a patch inherits
%                    the class of its image only where the mask exceeds 0.5
%                    at the keypoint, otherwise it is labelled 0.
%   hlen           - patch half-width; patches are (2*hlen+1) pixels square.
%   Returns one flattened patch per row of PATCHES and the matching labels
%   in PATCHCLASSLIST.
keypointsPerImage=200;
imageCount=length(I);
sideLen=2*hlen+1;

% Reserve room for the expected upper bound on the number of patches; the
% unused tail is trimmed off once all images have been processed.
capacity=imageCount*keypointsPerImage;
patches=zeros(capacity,sideLen^2,'single');
patchClassList=zeros(1,capacity,'single');

numStored=0;
for imgIdx=1:imageCount
    img=I{imgIdx};
    keypoints=extract_keypoints(img,keypointsPerImage);
    for kpIdx=1:size(keypoints,1)
        kp=keypoints(kpIdx,1:2);
        Ipatch=extract_patch(img,hlen,kp);
        % Skip keypoints for which no complete, NaN-free patch was returned.
        if isempty(Ipatch) || any(isnan(Ipatch(:)))
            continue
        end
        numStored=numStored+1;
        patches(numStored,:)=Ipatch(:)';
        % Patches taken from outside the object region count as background.
        if objRegion{imgIdx}(kp(1),kp(2))>0.5
            label=imageClassList(imgIdx);
        else
            label=0;
        end
        patchClassList(numStored)=label;
    end
end

% Discard the slots that were never filled.
patches=patches(1:numStored,:);
patchClassList=patchClassList(1:numStored);
disp(['extracted ',int2str(numStored),' image patches (including ',int2str(sum(patchClassList==0)),' for class 0)'])
function [clustersAllClasses,clusterClassList]=cluster_patches(patches,patchClassList,numClusterMembersReqd,similarityThres)
% CLUSTER_PATCHES  Cluster the patches of each class into codebook entries.
%   Each class is clustered on its own. For classes above 0, clusters that
%   resemble a retained class-0 (background) cluster too closely are
%   dropped, and only clusters with more than NUMCLUSTERMEMBERSREQD
%   members are kept.
%   Returns the retained cluster centres (one per row) and the class label
%   of each one.
clustersAllClasses=[];
clusterClassList=[];
backgndKnown=false; %becomes true once class 0 has been processed
for c=unique(patchClassList)
    %cluster the patches belonging to this class only
    inClass=(patchClassList==c);
    [clusters,len]=agglomerative_clustering(patches(inClass,:),similarityThres);

    %drop clusters that look like something seen in the background images
    %NOTE(review): if no background cluster was retained, patchesBackgnd is
    %empty here - confirm distance_measure copes with that.
    if c>0 && backgndKnown
        numClusters=size(clusters,1);
        simBackgnd=zeros(1,numClusters);
        for n=1:numClusters
            %similarity to the closest background cluster
            simBackgnd(n)=max(distance_measure(clusters(n,:),patchesBackgnd,1));
        end
        isDistinct=simBackgnd<0.93;
        clusters=clusters(isDistinct,:);
        len=len(isDistinct);
    end

    %retain only well-populated clusters, i.e. frequently occurring patches
    isFrequent=len>numClusterMembersReqd;
    numKept=nnz(isFrequent);
    clustersAllClasses=[clustersAllClasses;clusters(isFrequent,:)];
    clusterClassList=[clusterClassList,c.*ones(1,numKept)];

    %remember the clusters obtained from the background images
    if c==0
        patchesBackgnd=clustersAllClasses;
        backgndKnown=true;
    end
    disp(['extracted ',int2str(numKept),' clusters for class ',int2str(c)])
end
function [clusters,len]=agglomerative_clustering(data,similarityThres)
% AGGLOMERATIVE_CLUSTERING  Complete-linkage clustering of the rows of DATA.
%   Performs hierarchical agglomerative clustering by calling the built-in
%   matlab function "clusterdata". That method is very memory hungry due to
%   the calculation of the similarity of every pair of samples. Hence, if
%   there are lots of samples, clustering is performed multiple times on
%   non-overlapping subsets of samples and the results are merged - a
%   kludge!
%
%   Inputs:
%     data            - one sample per row.
%     similarityThres - similarity threshold; clusterdata is given the
%                       distance cutoff 1-similarityThres.
%   Outputs:
%     clusters - one cluster centre per row (the mean of the members).
%     len      - row vector with the number of members of each cluster.
maxNumPerClustering=40000;
numSamplesTotal=size(data,1);

%halve the subset size until a single clustering run is small enough
numSamplesPerClustering=numSamplesTotal;
halvings=0;
while numSamplesPerClustering>maxNumPerClustering
    numSamplesPerClustering=ceil(numSamplesPerClustering/2);
    halvings=halvings+1;
end
repeats=2^halvings;
if repeats>1
    disp(['WARNING: due to memory limitations APPROXIMATE clustering being performed using ',...
          int2str(repeats),' iterations']);
end

clusters=[];
len=[];
for it=1:repeats
    indToCluster=1+(it-1)*numSamplesPerClustering:min(it*numSamplesPerClustering,numSamplesTotal);
    patchesToCluster=data(indToCluster,:);
    %note clustering is based on distance, not similarity, so need to use
    %1-similarity threshold
    clusterIndex=clusterdata(patchesToCluster,'cutoff',1-similarityThres,'criterion','distance','linkage','complete','distance',@distance_measure);

    %place all patches into their clusters
    numClusters=max(clusterIndex);
    newClusters=[];newLen=[];
    for i=1:numClusters
        indClusterMembers=find(clusterIndex==i);
        newClusters(i,:)=mean(patchesToCluster(indClusterMembers,:),1); %each cluster is mean of all members
        newLen(i)=length(indClusterMembers);
    end

    %merge clusters that are similar to those already found on previous iterations
    if it>1
        simExisting=zeros(1,numClusters);
        indExisting=zeros(1,numClusters); %reset so no stale indices survive from the previous pass
        for i=1:numClusters
            %find similarity of closest patch in existing clusters
            [simExisting(i),indExisting(i)]=max(distance_measure(newClusters(i,:),clusters,1));
        end
        %Credit the members of each near-duplicate cluster to the existing
        %cluster it matches. This is done one cluster at a time because
        %several new clusters may match the SAME existing cluster, and a
        %vectorised len(idx)=len(idx)+... keeps only the last write for a
        %repeated index, silently losing members.
        for i=find(simExisting>=0.97)
            len(indExisting(i))=len(indExisting(i))+newLen(i);
        end
        keep=find(simExisting<0.97);
        newClusters=newClusters(keep,:);
        newLen=newLen(keep);
    end

    %add remaining clusters to those found on previous iterations
    clusters=[clusters;newClusters];
    len=[len,newLen];
end
function locations=extract_locations(I,imageClassList,objLocation,hlen,patches,patchClassList,similarityThres)
% EXTRACT_LOCATIONS  Build the vote arrays of every codebook entry.
%   For each non-background class c, locations{c}{j} is a fixed-size array
%   (as large as the largest image of that class) belonging to the j-th
%   non-background codebook entry. A 1 in that array marks an offset,
%   measured from the array centre, at which the object centre was seen
%   relative to a keypoint whose patch matched that entry.
%
%   similarityThres is accepted but not used: a match is accepted whenever
%   its strength is above 0.
maxMatches=1;
numPatchesPerImage=200;

classes=unique(imageClassList);
if classes(1)==0
    classes=classes(2:end); %class 0 is background - no voting required
end

indPatchesInClass=find(patchClassList>0); %all non-background patches
numEntries=length(indPatchesInClass);

for c=classes
    indImagesInClass=find(imageClassList==c); %all images in class
    numImagesInClass=length(indImagesInClass);

    %size the vote arrays to fit the largest training image of the class
    A=0;B=0;
    for i=1:numImagesInClass
        [numRows,numCols]=size(I{indImagesInClass(i)});
        A=max(A,numRows);
        B=max(B,numCols);
    end
    for j=1:numEntries
        locations{c}{j}=zeros(A,B,'single');
    end

    totalLocations=0;
    for i=1:numImagesInClass
        imgIdx=indImagesInClass(i);
        img=I{imgIdx};
        keypoints=extract_keypoints(img,numPatchesPerImage);
        for k=1:size(keypoints,1)
            kp=keypoints(k,1:2);
            Ipatch=extract_patch(img,hlen,kp);
            %ignore keypoints without a complete, NaN-free patch
            if isempty(Ipatch) || any(isnan(Ipatch(:)))
                continue
            end
            %rank the codebook entries by similarity to this patch
            similarity=distance_measure(Ipatch(:)',patches,1);
            [matchingStrength,matchingCluster]=sort(similarity,'descend');
            for m=1:maxMatches
                best=matchingCluster(m);
                isMatch=patchClassList(best)>0 && matchingStrength(m)>0;
                if ~isMatch
                    break %stop at the first candidate failing the matching criteria
                end
                %position of the object centre relative to the keypoint,
                %expressed as an offset from the centre of the vote array
                voteLocation=round([A,B]./2+objLocation(imgIdx,:)-kp);
                insideArray=voteLocation(1)>=1 && voteLocation(2)>=1 && voteLocation(1)<=A && voteLocation(2)<=B;
                if ~insideArray
                    disp('WARNING: location fell off edge of vote array');
                    voteLocation
                else
                    matchedClassCluster=find(indPatchesInClass==best);
                    r=voteLocation(1);
                    s=voteLocation(2);
                    locations{c}{matchedClassCluster}(r,s)=max(locations{c}{matchedClassCluster}(r,s),1);
                    totalLocations=totalLocations+1;
                end
            end
        end
    end
    disp(['extracted ',int2str(totalLocations),' locations for class ',int2str(c)])
end
function locations=extract_locations_optsize(I,imageClassList,objLocation,hlen,patches,patchClassList,similarityThres)
% EXTRACT_LOCATIONS_OPTSIZE  Build vote arrays that grow on demand.
%   Same matching procedure as extract_locations, but each vote array
%   locations{c}{j} starts as a single element and is enlarged by
%   expanding_location_update whenever a vote falls outside it.
%
%   similarityThres is accepted but not used: a match is accepted whenever
%   its strength is above 0.
maxMatches=1;
numPatchesPerImage=200;

classes=unique(imageClassList);
if classes(1)==0
    classes=classes(2:end); %class 0 is background - no voting required
end

indPatchesInClass=find(patchClassList>0); %all non-background patches
numEntries=length(indPatchesInClass);

for c=classes
    indImagesInClass=find(imageClassList==c); %all images in class

    %every codebook entry starts with a 1x1 vote array
    for j=1:numEntries
        locations{c}{j}=single(0);
    end

    totalLocations=0;
    for i=1:length(indImagesInClass)
        imgIdx=indImagesInClass(i);
        img=I{imgIdx};
        keypoints=extract_keypoints(img,numPatchesPerImage);
        for k=1:size(keypoints,1)
            kp=keypoints(k,1:2);
            Ipatch=extract_patch(img,hlen,kp);
            %ignore keypoints without a complete, NaN-free patch
            if isempty(Ipatch) || any(isnan(Ipatch(:)))
                continue
            end
            %rank the codebook entries by similarity to this patch
            similarity=distance_measure(Ipatch(:)',patches,1);
            [matchingStrength,matchingCluster]=sort(similarity,'descend');
            for m=1:maxMatches
                best=matchingCluster(m);
                isMatch=patchClassList(best)>0 && matchingStrength(m)>0;
                if ~isMatch
                    break %stop at the first candidate failing the matching criteria
                end
                %offset of the object centre from the keypoint
                centreOffset=objLocation(imgIdx,:)-kp;
                %record the vote, enlarging the array when necessary
                entry=find(indPatchesInClass==best);
                locations{c}{entry}=expanding_location_update(locations{c}{entry},centreOffset);
                totalLocations=totalLocations+1;
            end
        end
    end
    disp(['extracted ',int2str(totalLocations),' locations for class ',int2str(c)])
end
function L=expanding_location_update(L,voteOffset)
% EXPANDING_LOCATION_UPDATE  Mark a vote in L, enlarging L if required.
%   voteOffset is the position of the vote relative to the centre of the
%   vote array L. When that position lies outside L, L is zero-padded by
%   the same amount on both sides of the affected dimension (so the centre
%   stays in the middle) before the entry is set to 1.

%where the vote would land in the array as it currently stands
[numRows,numCols]=size(L);
target=ceil([numRows,numCols]./2+voteOffset);

%amount by which the vote overshoots the array in each dimension (0 if it fits)
rowOverhang=max([0,1-target(1),target(1)-numRows]);
colOverhang=max([0,1-target(2),target(2)-numCols]);
L=padarray(L,[rowOverhang,colOverhang],0,'both');

%the array may have grown, so locate the vote again before recording it
[numRows,numCols]=size(L);
target=ceil([numRows,numCols]./2+voteOffset);
r=target(1);
s=target(2);
L(r,s)=max(L(r,s),1);