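% NOTE: the lines below are residue from the repository web page this file was
% copied from; they are not part of the program.
% You cannot select more than 25 topics
% Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
% 407 lines, 15 KiB, Matlab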
function train_ism(task)
%TRAIN_ISM  Train an implicit shape model (ISM) codebook.
%
%  train_ism(task) builds the codebook for the training image set named by
%  TASK (default 'cars'; load_dataset lists the recognised names) and saves
%  it, in MAT-file version 7.3 format, to ISM_codebook_<task>.mat. That file
%  holds the variables patches, patchClassList, Locations and
%  similarityThres. Finally the codebook patches and their vote positions
%  are plotted with plot_codebook.

if nargin<1 || isempty(task)
  task='cars'; %default training set
end

%1. load the training images
[images,imageClasses,objCentres,objRegions]=load_dataset(task);

%2. extract patches around keypoints; the multi-scale car set uses larger patches
if strcmp(task,'carsScale')
  hlen=8;
else
  hlen=7;
end
[patches,patchClassList]=extract_patches(images,imageClasses,objRegions,hlen);
%patchClassList(patchClassList>1)=1; %uncomment to make patches non class-specific

%3. cluster the patches to form the codebook
minClusterSize=12;   %clusters must have MORE than this many members to be kept
similarityThres=0.4;
[patches,patchClassList]=cluster_patches(patches,patchClassList,minClusterSize,similarityThres);

%4. for each codebook entry, locate the object centre relative to the patch position
Locations=extract_locations(images,imageClasses,objCentres,hlen,patches,patchClassList,similarityThres);

%5. save the codebook (no semicolon: the file name is echoed to the command window)
filename=['ISM_codebook_',task,'.mat']
save(filename,'-v7.3','patches','patchClassList','Locations','similarityThres')

%show codebook entries: patches and vote positions
figured(1), clf, plot_codebook(patches,patchClassList,Locations);

function [TrainingImages,classList,objLocation,objRegion]=load_dataset(task)
%LOAD_DATASET  Load the training images for the data set named by TASK.
%
% Training images should contain one object only.
%
%  task - 'cars', 'carsScale', 'horses', 'USPS' or 'NORB'
%
%  TrainingImages - cell array of training images
%  classList      - class label of each image (0 = background)
%  objLocation    - row i holds the [row,col] object centre of image i
%  objRegion      - cell array of single-precision maps, one per image;
%                   extract_patches gives a patch the image's class label only
%                   where this map exceeds 0.5 (otherwise label 0)
%
% Raises error 'train_ism:unknownDataset' if TASK is not recognised.

switch task

  case 'cars'
    %background images
    numBackgnd=500;
    for i=1:numBackgnd
      TrainingImages{i}=load_image_car(i-1,2);
      classList(i)=0;
      objRegion{i}=ones(size(TrainingImages{i}),'single');
    end
    %car images
    for i=1:550
      TrainingImages{numBackgnd+i}=load_image_car(i-1,1);
      classList(numBackgnd+i)=1;
      objLocation(numBackgnd+i,1:2)=size(TrainingImages{numBackgnd+i})./2; %assumes all objects are centred in the training images.
      objRegion{numBackgnd+i}=ones(size(TrainingImages{numBackgnd+i}),'single');
    end

  case 'carsScale'
    TrainingImages=[];classList=[];objLocation=[];objRegion=[];
    %background images
    numBackgnd=500;
    for i=1:numBackgnd
      I=load_image_car(i-1,2);
      %pass the accumulated objLocation, as every other expand_* call does;
      %passing [] here would presumably discard the rows built up so far
      [TrainingImages,classList,objLocation,objRegion]=expand_car_image_scales(0,i,TrainingImages,classList,objLocation,I,0,objRegion);
    end
    %car images
    for i=1:550
      I=load_image_car(i-1,1);
      [TrainingImages,classList,objLocation,objRegion]=expand_car_image_scales(numBackgnd,i,TrainingImages,classList,objLocation,I,1,objRegion);
    end

  case 'horses'
    TrainingImages=[];classList=[];objLocation=[];objRegion=[];
    %background images
    numBackgnd=100;
    for i=1:numBackgnd
      I=load_image_101(i);
      %alternative: expand_horse_image_set(0,i,TrainingImages,classList,objLocation,objRegion,I,[],0)
      [TrainingImages,classList,objLocation,objRegion]=expand_affine_image_set(0,i,TrainingImages,classList,objLocation,objRegion,I,[],0);
    end
    %horse images (use first 100 for training, next 228 will be for testing)
    for i=1:100
      [I,groundtruth]=load_image_horse(i);
      %alternative: expand_horse_image_set(numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,groundtruth,1)
      [TrainingImages,classList,objLocation,objRegion]=expand_affine_image_set(numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,groundtruth,1);
    end

  case 'USPS'
    [images,classes,trainingIndeces,testingIndeces,numClasses]=load_data_usps;
    numBackgnd=0;
    TrainingImages=[];classList=[];objLocation=[];objRegion=[];
    for i=trainingIndeces
      i %display progress
      I=preprocess_usps_image(images(i,:));
      [TrainingImages,classList,objLocation,objRegion]=expand_affine_image_set(numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,I,classes(i));
    end

  case 'NORB'
    [images,classes,numClasses]=load_data_norb(1);
    numBackgnd=0;
    for i=1:length(classes)
      TrainingImages{numBackgnd+i}=reshape(images(i,:),96,96)';
      classList(numBackgnd+i)=classes(i);
      objLocation(numBackgnd+i,1:2)=size(TrainingImages{numBackgnd+i})./2; %assumes all objects are centred in the training images.
      objRegion{numBackgnd+i}=ones(size(TrainingImages{numBackgnd+i}),'single');
    end

  otherwise
    %fail here with a clear message; previously execution continued and then
    %crashed on the undefined output variables
    error('train_ism:unknownDataset','ERROR: unknown data set ''%s''',task);

end
disp(['loaded ',int2str(length(TrainingImages)),' training images']);

function [patches,patchClassList]=extract_patches(I,imageClassList,objRegion,hlen)
%EXTRACT_PATCHES  Cut a square patch around each keypoint of every image.
%
%  I              - cell array of images
%  imageClassList - class label of each image
%  objRegion      - cell array of object-region maps, one per image
%  hlen           - patch half-length; patches are (2*hlen+1)-by-(2*hlen+1)
%
%  patches        - single matrix, one vectorised patch per row
%  patchClassList - single row vector: the image's class label when the
%                   keypoint lies where objRegion>0.5, otherwise 0 (background)
%
% Patches that extract_patch returns empty or containing NaN are skipped.

maxKeypoints=200;  %keypoints requested from extract_keypoints per image
side=1+2*hlen;

%allocate generously up front, then trim to the number actually used
capacity=length(I)*maxKeypoints;
patches=zeros(capacity,side^2,'single');
patchClassList=zeros(1,capacity,'single');

count=0;
for n=1:length(I)
  im=I{n};
  keypoints=extract_keypoints(im,maxKeypoints);
  for k=1:size(keypoints,1)
    r=keypoints(k,1);
    c=keypoints(k,2);
    Ipatch=extract_patch(im,hlen,[r,c]);
    if isempty(Ipatch) || any(isnan(Ipatch(:)))
      continue; %unusable patch
    end

    count=count+1;
    patches(count,:)=Ipatch(:)';

    %only keypoints inside the object region inherit the image's class
    label=0;
    if objRegion{n}(r,c)>0.5
      label=imageClassList(n);
    end
    patchClassList(count)=label;
  end
end

patches=patches(1:count,:);
patchClassList=patchClassList(1:count);
disp(['extracted ',int2str(count),' image patches (including ',int2str(sum(patchClassList==0)),' for class 0)'])

function [clustersAllClasses,clusterClassList]=cluster_patches(patches,patchClassList,numClusterMembersReqd,similarityThres)
%CLUSTER_PATCHES  Build the codebook by clustering the patches of each class separately.
%
%  patches               - one patch per row
%  patchClassList        - class label of each row of PATCHES (0 = background)
%  numClusterMembersReqd - a cluster is kept only if it has MORE than this many members
%  similarityThres       - similarity threshold passed to agglomerative_clustering
%
%  clustersAllClasses - one cluster centre per row
%  clusterClassList   - class label of each row of clustersAllClasses
%
% Foreground clusters that are too similar to a retained background (class 0)
% cluster are discarded.

maxBackgndSimilarity=0.93; %foreground clusters at least this similar to a background cluster are dropped

clustersAllClasses=[];
clusterClassList=[];
patchesBackgnd=[]; %background clusters; filled in on the c==0 iteration

%unique() returns sorted labels, so background class 0 (if present) is processed first
classes=unique(patchClassList);
for c=classes
  indPatchesInClass=find(patchClassList==c);

  %cluster patches, to find cluster centres
  [clusters,len]=agglomerative_clustering(patches(indPatchesInClass,:),similarityThres);

  %remove clusters that are similar to clusters found in background images.
  %Skipped when there are no background clusters: max() of an empty
  %similarity vector is [], and "simBackgnd(i)=[]" would DELETE elements.
  if c>0 && ~isempty(patchesBackgnd)
    numClusters=size(clusters,1);
    simBackgnd=zeros(1,numClusters);
    for i=1:numClusters
      %find similarity of closest patch in background
      simBackgnd(i)=max(distance_measure(clusters(i,:),patchesBackgnd,1));
    end
    keep=find(simBackgnd<maxBackgndSimilarity);
    clusters=clusters(keep,:);
    len=len(keep);
  end

  %keep only those clusters that contain the most exemplars - i.e. the most frequently encountered patches
  keep=find(len>numClusterMembersReqd);
  clustersAllClasses=[clustersAllClasses;clusters(keep,:)];
  clusterClassList=[clusterClassList,c.*ones(1,length(keep))];

  %remember clusters extracted from background images
  if c==0, patchesBackgnd=clustersAllClasses; end

  disp(['extracted ',int2str(length(keep)),' clusters for class ',int2str(c)])
end

function [clusters,len]=agglomerative_clustering(data,similarityThres)
%AGGLOMERATIVE_CLUSTERING  Hierarchical agglomerative clustering of the rows of DATA.
%
%Performs hierarchical agglomerative clustering by calling the built-in matlab
%function "clusterdata" (complete linkage, custom @distance_measure, cut at
%distance 1-similarityThres). However, this method of clustering is very memory
%hungry due to the calculation of the similarity of every pair of
%samples. Hence, if there are lots of samples, clustering is performed multiple
%times on non-overlapping subsets of samples and the results are merged - a kludge!
%
%  data            - one sample per row
%  similarityThres - samples are clustered together down to this similarity
%
%  clusters - one cluster centre (mean of its members) per row
%  len      - row vector, number of samples assigned to each cluster
%
% NOTE(review): distance_measure(a,B,1) is used below as a SIMILARITY (its
% maximum picks the closest cluster); presumably the third argument selects
% similarity rather than distance - confirm against distance_measure.

maxNumPerClustering=40000; %largest subset clustered in one call to clusterdata
mergeSimilarity=0.97;      %a new cluster at least this similar to an existing one is merged into it

numSamplesTotal=size(data,1);

%halve the subset size until it fits; 2^halvings subsets then cover all samples
numSamplesPerClustering=numSamplesTotal;
halvings=0;
while numSamplesPerClustering>maxNumPerClustering
  numSamplesPerClustering=ceil(numSamplesPerClustering/2);
  halvings=halvings+1;
end
repeats=2^halvings;

if repeats>1
  disp(['WARNING: due to memory limitations APPROXIMATE clustering being performed using ',...
        int2str(repeats),' iterations']);
end

clusters=[];
len=[];
for it=1:repeats
  indToCluster=1+(it-1)*numSamplesPerClustering:min(it*numSamplesPerClustering,numSamplesTotal);
  patchesToCluster=data(indToCluster,:);

  %note clustering is based on distance, not similarity, so need to use 1-similarity threshold
  clusterIndex=clusterdata(patchesToCluster,'cutoff',1-similarityThres,'criterion','distance','linkage','complete','distance',@distance_measure);

  %place all patches into their clusters
  numClusters=max(clusterIndex);
  newClusters=[];newLen=[];
  for i=1:numClusters
    indClusterMembers=find(clusterIndex==i);
    newClusters(i,:)=mean(patchesToCluster(indClusterMembers,:),1); %each cluster is mean of all members
    newLen(i)=length(indClusterMembers);
  end

  %remove clusters that are similar to those already found on previous iterations
  if it>1
    simExisting=zeros(1,numClusters);
    indExisting=zeros(1,numClusters);
    for i=1:numClusters
      %find similarity of closest patch in existing clusters
      [simExisting(i),indExisting(i)]=max(distance_measure(newClusters(i,:),clusters,1));
    end
    addto=find(simExisting>=mergeSimilarity);
    %accumarray sums the sizes of ALL new clusters merging into the same existing
    %cluster; indexed "len(idx)=len(idx)+..." keeps only the last one when idx repeats
    len=len+accumarray(indExisting(addto)',newLen(addto)',[numel(len),1])';
    keep=find(simExisting<mergeSimilarity);
    newClusters=newClusters(keep,:);
    newLen=newLen(keep);
  end

  %add remaining clusters to those found on previous iterations
  clusters=[clusters;newClusters];
  len=[len,newLen];
end

function locations=extract_locations(I,imageClassList,objLocation,hlen,patches,patchClassList,similarityThres)
%EXTRACT_LOCATIONS  Define, for each codebook entry, where it casts votes relative to the patch/keypoint.
%
% For every training image of each foreground class, the patch around each
% keypoint is compared with the codebook PATCHES. When the best match is a
% non-background codebook entry with positive similarity, the offset from
% the keypoint to the object centre is marked in that entry's vote array.
%
%  I              - cell array of training images
%  imageClassList - class label of each image (0 = background, which casts no votes)
%  objLocation    - row i holds the [row,col] object centre of image i
%  hlen           - patch half-length passed to extract_patch
%  patches        - codebook, one entry per row
%  patchClassList - class label of each codebook entry
%  similarityThres- not used: a match only requires similarity > 0
%
%  locations{c}{j} - single A-by-B vote array of the j-th non-background
%                    codebook entry for class c. A-by-B is the largest image
%                    size in class c, and a zero offset maps to
%                    round([A,B]/2). Votes are binary (0 or 1).
%                    Returned as {} when there are no foreground classes.

maxMatches=1;            %best-matching codebook entries considered per keypoint
numPatchesPerImage=200;  %keypoints requested per image

locations={}; %defined even when no foreground class is processed
classes=unique(imageClassList);
if ~isempty(classes) && classes(1)==0, classes=classes(2:end); end %class 0 is background - no voting required
indPatchesInClass=find(patchClassList>0); %all non-background patches (the same for every class)

for c=classes
  indImagesInClass=find(imageClassList==c); %all images in class

  %vote arrays are as large as the largest image in the class
  A=0;B=0;
  for i=1:length(indImagesInClass)
    [a,b]=size(I{indImagesInClass(i)});
    A=max(A,a);B=max(B,b);
  end

  %initialise voting weights for each codebook entry
  for j=1:length(indPatchesInClass)
    locations{c}{j}=zeros(A,B,'single');
  end

  totalLocations=0;
  for i=1:length(indImagesInClass)
    imgIndex=indImagesInClass(i);
    %extract keypoints from training image
    keypoints=extract_keypoints(I{imgIndex},numPatchesPerImage);

    for k=1:size(keypoints,1)
      %compare patch around keypoint to the patches in codebook
      Ipatch=extract_patch(I{imgIndex},hlen,keypoints(k,1:2));
      if isempty(Ipatch) || any(isnan(Ipatch(:)))
        continue; %unusable patch
      end
      similarity=distance_measure(Ipatch(:)',patches,1);
      %sort (rather than max) so the maxMatches best entries are visited in order
      [matchingStrength,matchingCluster]=sort(similarity,'descend');

      for m=1:maxMatches
        %stop the first time the matching criterion fails
        if ~(patchClassList(matchingCluster(m))>0 && matchingStrength(m)>0)
          break;
        end
        %position of object centre relative to the keypoint, stored at that
        %offset from the centre of the vote weight array
        voteLocation=round([A,B]./2+objLocation(imgIndex,:)-keypoints(k,1:2));
        if voteLocation(1)>=1 && voteLocation(2)>=1 && voteLocation(1)<=A && voteLocation(2)<=B
          matchedClassCluster=find(indPatchesInClass==matchingCluster(m));
          locations{c}{matchedClassCluster}(voteLocation(1),voteLocation(2))=...
            max(locations{c}{matchedClassCluster}(voteLocation(1),voteLocation(2)),1);
          totalLocations=totalLocations+1;
        else
          disp('WARNING: location fell off edge of vote array'); voteLocation
        end
      end
    end
  end
  disp(['extracted ',int2str(totalLocations),' locations for class ',int2str(c)])
end

function locations=extract_locations_optsize(I,imageClassList,objLocation,hlen,patches,patchClassList,similarityThres)
%EXTRACT_LOCATIONS_OPTSIZE  Define, for each codebook entry, where it casts votes relative to the patch/keypoint,
%using vote arrays that grow only as large as the recorded votes require.
%
% Same matching procedure as extract_locations. Each vote array starts as
% a 1-by-1 single(0) and is enlarged by expanding_location_update as votes
% are added.
%
%  I              - cell array of training images
%  imageClassList - class label of each image (0 = background, which casts no votes)
%  objLocation    - row i holds the [row,col] object centre of image i
%  hlen           - patch half-length passed to extract_patch
%  patches        - codebook, one entry per row
%  patchClassList - class label of each codebook entry
%  similarityThres- not used: a match only requires similarity > 0
%
%  locations{c}{j} - vote array of the j-th non-background codebook entry for
%                    class c. Returned as {} when there are no foreground classes.

maxMatches=1;            %best-matching codebook entries considered per keypoint
numPatchesPerImage=200;  %keypoints requested per image

locations={}; %defined even when no foreground class is processed
classes=unique(imageClassList);
if ~isempty(classes) && classes(1)==0, classes=classes(2:end); end %class 0 is background - no voting required
indPatchesInClass=find(patchClassList>0); %all non-background patches (the same for every class)

for c=classes
  indImagesInClass=find(imageClassList==c); %all images in class

  %initialise voting weights for each codebook entry
  for j=1:length(indPatchesInClass)
    locations{c}{j}=single(0);
  end

  totalLocations=0;
  for i=1:length(indImagesInClass)
    imgIndex=indImagesInClass(i);
    %extract keypoints from training image
    keypoints=extract_keypoints(I{imgIndex},numPatchesPerImage);

    for k=1:size(keypoints,1)
      %compare patch around keypoint to the patches in codebook
      Ipatch=extract_patch(I{imgIndex},hlen,keypoints(k,1:2));
      if isempty(Ipatch) || any(isnan(Ipatch(:)))
        continue; %unusable patch
      end
      similarity=distance_measure(Ipatch(:)',patches,1);
      %sort (rather than max) so the maxMatches best entries are visited in order
      [matchingStrength,matchingCluster]=sort(similarity,'descend');

      for m=1:maxMatches
        %stop the first time the matching criterion fails
        if ~(patchClassList(matchingCluster(m))>0 && matchingStrength(m)>0)
          break;
        end
        %calculate position of object centre relative to patch
        voteLocation=objLocation(imgIndex,:)-keypoints(k,1:2);
        %add weight at this offset from the centre of the vote weight array
        matchedClassCluster=find(indPatchesInClass==matchingCluster(m));
        locations{c}{matchedClassCluster}=...
          expanding_location_update(locations{c}{matchedClassCluster},voteLocation);
        totalLocations=totalLocations+1;
      end
    end
  end
  disp(['extracted ',int2str(totalLocations),' locations for class ',int2str(c)])
end

function L=expanding_location_update(L,voteOffset)
%EXPANDING_LOCATION_UPDATE  Record a vote in vote array L, enlarging L if needed.
%
%  L          - 2-D vote-weight array; a zero offset maps to ceil(size(L)/2)
%  voteOffset - [row,col] offset of the vote from the centre of L
%
% If the vote would land outside L, L is zero-padded by the same amount on
% both sides of each offending dimension, so the vote lands on the new
% border. The element at the vote position is then set to max(itself,1).

[nRows,nCols]=size(L);
target=ceil([nRows,nCols]./2+voteOffset);

%distance by which the vote lies outside the array in each dimension (0 if inside)
pad=[max([0,1-target(1),target(1)-nRows]), max([0,1-target(2),target(2)-nCols])];
L=padarray(L,pad,0,'both');

%re-centre on the (possibly) enlarged array and record the vote
[nRows,nCols]=size(L);
target=ceil([nRows,nCols]./2+voteOffset);
r=target(1);
c=target(2);
L(r,c)=max(L(r,c),1);
