You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

395 lines
15 KiB
Matlab

function ismdim_training(task)
% ISMDIM_TRAINING  Performs the training process of the implicit shape model (ISM).
%
%   ismdim_training(task) builds a codebook for the training image set named
%   by TASK (default 'cars') and writes it to a .mat file in the current
%   directory. The method differs from the standard ISM in the way patches are
%   matched to the image (this method uses DIM).
%
%   Steps: load the training images, extract patches around keypoints,
%   cluster the patches into a codebook, record for every codebook entry where
%   the object centre lies relative to the patch, save and plot the result.
%
%   Saved variables: patches, patchClassList, Locations, sigmaLGN.
if nargin<1 || isempty(task), task='cars'; end

%LOAD TRAINING IMAGES
[TrainingImages,imageClassList,objLocation,objRegion]=load_dataset(task);

%EXTRACT PATCHES FROM AROUND KEYPOINTS IN THE TRAINING IMAGES
patchHalfLen=7;
sigmaLGN=patchHalfLen/2;
[patches,patchClassList]=extract_patches_all(TrainingImages,imageClassList,objRegion,patchHalfLen,sigmaLGN);

%CLUSTER PATCHES TO FORM CODEBOOK
patchClustering='agglom';%'dim'; %
switch patchClustering
  case 'dim'
    %one entry per class label (entry c+1 is used for class c)
    %agglom version of carsScale: [2465, 193, (273), 444, 651, 877, 1163]
    %agglom version of pedestrians: [4799, 53, 121, 191, (241), 310, 393]
    %numClusters=[150,150]
    %numClusters=[150,150,150,150,150,150,150]
    numClusters=[900,150,150,150,150,150,150]
    %numClusters=[1500,75,150,225,300,375,450]
    %numClusters=[1050,50,100,150,200,250,300]
    %numClusters=[1000,1000]
    [patches,patchClassList]=cluster_patches_dim(patches,patchClassList,numClusters);
  case 'agglom'
    numClusterMembers=12;
    similarityThres=0.4;
    [patches,patchClassList]=cluster_patches_agglom(patches,patchClassList,numClusterMembers,similarityThres);
  otherwise
    %without this branch a mistyped method name would silently use every raw
    %patch as a codebook entry
    error('ismdim_training:unknownClustering','unknown patch clustering method ''%s''',patchClustering);
end

%FOR EACH CODEBOOK ENTRY LOCATE CENTRE OF OBJECT RELATIVE TO PATCH POSITION
Locations=extract_locations(TrainingImages,imageClassList,objLocation,objRegion,patches,patchClassList,sigmaLGN);

%SAVE CODEBOOK TO DISK
%filename=['ISMDIM_codebook_',task,'_nodiffboundary_400imagesandflipped_Linregion.mat']
filename=['ISMDIM_codebook_',task,'_nodiffboundary_flipped0_Linregion_nocrop_e1new.mat']
save(filename,'-v7.3','patches','patchClassList','Locations','sigmaLGN')

%show codebook entries: patches and vote positions
figured(1),clf,plot_codebook(patches,patchClassList,Locations);
function [TrainingImages,classList,objLocation,objRegion]=load_dataset(task)
% LOAD_DATASET  Loads the training images and annotations for the named task.
%
%   task is one of 'cars', 'carsScale', 'pedestrians', 'horses', 'USPS',
%   'NORB'. Any other value raises an error.
%
%   Outputs (element i of each describes training image i):
%     TrainingImages - cell array of images
%     classList      - class label per image (0 is used for background images)
%     objLocation    - one row per image, [NaN NaN] for background images
%     objRegion      - cell array of single-precision masks, same size as the image
%
%   Training images should contain one object only.

%cell outputs start as empty cells, numeric outputs as empty single arrays
TrainingImages={};
objRegion={};
classList=single([]);
objLocation=single([]);

switch task
  case 'cars'
    %background images
    numBackgnd=500;
    flipped=0
    for i=1:numBackgnd
      TrainingImages{i}=load_image_car(i-1,2);
      classList(i)=0;
      objLocation(i,1:2)=[NaN,NaN];
      objRegion{i}=zeros(size(TrainingImages{i}),'single');
    end
    %car images
    for i=1:550
      TrainingImages{numBackgnd+i}=load_image_car(i-1,1);
      classList(numBackgnd+i)=1;
      objLocation(numBackgnd+i,1:2)=size(TrainingImages{numBackgnd+i})./2;%assumes all objects are centred in the training images.
      objRegion{numBackgnd+i}=ones(size(TrainingImages{numBackgnd+i}),'single');
    end
    if flipped
      [TrainingImages,classList,objLocation,objRegion]=append_mirrored_images(TrainingImages,classList,objLocation,objRegion);
    end

  case 'carsScale'
    %background images
    scales=1.2.^[-1:4]
    numBackgnd=500;
    for i=1:numBackgnd
      I=load_image_car(i-1,2);
      [TrainingImages,classList,objLocation,objRegion]=expand_image_scales(scales,0,i,TrainingImages,classList,objLocation,objRegion,I,[],0);
    end
    %car images
    for i=1:550
      I=load_image_car(i-1,1);
      [TrainingImages,classList,objLocation,objRegion]=expand_image_scales(scales,numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,[],1);
    end

  case 'pedestrians'
    %background images
    scales=1.34.^[-3:2]
    numTrain=400; %210;
    numBackgnd=numTrain; %270;
    flipped=1
    for i=1:numBackgnd
      I=load_image_pedestrian(i,2);
      [TrainingImages,classList,objLocation,objRegion]=expand_image_scales(scales,0,i,TrainingImages,classList,objLocation,objRegion,I,[],0);
    end
    %pedestrian images
    for i=1:numTrain
      [I,groundtruth]=load_image_pedestrian(i,1,numTrain);
      [TrainingImages,classList,objLocation,objRegion]=expand_image_scales(scales,numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,groundtruth,1);
    end
    if flipped
      [TrainingImages,classList,objLocation,objRegion]=append_mirrored_images(TrainingImages,classList,objLocation,objRegion);
    end

  case 'horses'
    %background images
    numBackgnd=100;
    for i=1:numBackgnd
      %I=load_image_car(i-1,2);
      I=load_image_101(i);
      %[TrainingImages,classList,objLocation,objRegion]=expand_horse_image_set(0,i,TrainingImages,classList,objLocation,objRegion,I,[],0);
      %[TrainingImages,classList,objLocation,objRegion]=expand_affine_image_set(0,i,TrainingImages,classList,objLocation,objRegion,I,[],0);
      TrainingImages{i}=I;
      classList(i)=0;
      objLocation(i,1:2)=[NaN,NaN];
      objRegion{i}=ones(size(TrainingImages{i}),'single');
    end
    %horse images (use first 100 for training, next 228 will be for testing)
    for i=1:100
      [I,groundtruth]=load_image_horse(i);
      %[TrainingImages,classList,objLocation,objRegion]=expand_horse_image_set(numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,groundtruth,1);
      %[TrainingImages,classList,objLocation,objRegion]=expand_affine_image_set(numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,groundtruth,1);
      TrainingImages{numBackgnd+i}=I;
      classList(numBackgnd+i)=1;
      objLocation(numBackgnd+i,1:2)=fliplr(centroid(groundtruth));
      objRegion{numBackgnd+i}=single(imdilate(groundtruth,strel('disk',4)));
    end

  case 'USPS'
    [images,classes,trainingIndeces,testingIndeces,numClasses]=load_data_usps;
    numBackgnd=0;
    %store images consecutively: indexing the outputs by the dataset index
    %would leave empty entries if trainingIndeces were not exactly 1..N
    %(identical result when it is 1..N)
    out=numBackgnd;
    for i=trainingIndeces(:)'
      out=out+1;
      I=preprocess_usps_image(images(i,:));
      %[TrainingImages,classList,objLocation,objRegion]=expand_affine_image_set(numBackgnd,i,TrainingImages,classList,objLocation,objRegion,I,I,classes(i));
      TrainingImages{out}=I;
      classList(out)=classes(i);
      objLocation(out,1:2)=size(TrainingImages{out})./2;%assumes all objects are centred in the training images.
      objRegion{out}=ones(size(TrainingImages{out}),'single');%single(imdilate(I,strel('disk',4)));
    end

  case 'NORB'
    [images,classes,numClasses]=load_data_norb(1);
    numBackgnd=0;
    for i=1:length(classes)
      TrainingImages{numBackgnd+i}=reshape(images(i,:),96,96)';
      classList(numBackgnd+i)=classes(i);
      objLocation(numBackgnd+i,1:2)=size(TrainingImages{numBackgnd+i})./2;%assumes all objects are centred in the training images.
      objRegion{numBackgnd+i}=ones(size(TrainingImages{numBackgnd+i}),'single');
    end

  otherwise
    %stop here: continuing with an empty training set only fails later with
    %a less helpful message
    error('load_dataset:unknownTask','ERROR: unknown data set ''%s''',task);
end
disp(['loaded ',int2str(length(TrainingImages)),' training images']);

function [TrainingImages,classList,objLocation,objRegion]=append_mirrored_images(TrainingImages,classList,objLocation,objRegion)
% Appends a left-right mirrored copy of every image (and its annotations) to
% the training set. The first coordinate of objLocation is copied unchanged,
% the second is reflected about the image width.
numImages=length(classList);
for i=1:numImages
  TrainingImages{numImages+i}=fliplr(TrainingImages{i});
  classList(numImages+i)=classList(i);
  objLocation(numImages+i,1)=objLocation(i,1);
  objLocation(numImages+i,2)=size(TrainingImages{i},2)-objLocation(i,2)+1;
  objRegion{numImages+i}=fliplr(objRegion{i});
end
function [patches,patchClassList]=extract_patches_all(I,imageClassList,objRegion,patchHalfLen,sigmaLGN)
% Gathers keypoint patches from all training images into one matrix.
%
% Each row of "patches" is one patch and the matching entry of
% "patchClassList" is the class label of the image it came from. Every patch
% from a class-0 image is kept; for other images a patch is kept only when
% objRegion exceeds 0.5 at its keypoint.
% sigmaLGN is accepted but not currently passed on to extract_patches.
perImageQuota=200;
imageCount=length(I);
side=2*patchHalfLen+1;

%reserve (hopefully) more rows than will be needed; trimmed at the end
capacity=imageCount*perImageQuota;
patches=zeros(capacity,side^2,'single');
patchClassList=zeros(1,capacity,'single');

used=0;
for img=1:imageCount
  [candidates,kp]=extract_patches(I{img},patchHalfLen,perImageQuota,'corner');%,[],sigmaLGN);
  for p=1:size(kp,1)
    %skip patches from object images whose keypoint is outside the object region
    if imageClassList(img)~=0 && ~(objRegion{img}(kp(p,1),kp(p,2))>0.5)
      continue
    end
    %keep record of patch and class of image it came from
    used=used+1;
    patches(used,:)=candidates(p,:);
    patchClassList(used)=imageClassList(img);
  end
end

%discard the unused part of the reservation
patchClassList=patchClassList(1:used);
patches=patches(1:used,:);
disp(['extracted ',int2str(used),' image patches (including ',int2str(sum(patchClassList==0)),' for class 0)'])
function [clustersAllClasses,clusterClassList]=cluster_patches_dim(patches,patchClassList,n)
% CLUSTER_PATCHES_DIM  Learns a set of cluster centres per class using DIM.
%
%   patches        - one patch per row
%   patchClassList - class label of each patch
%   n              - number of clusters to learn per class; entry n(c+1) is
%                    used for class label c, so n needs an entry for every
%                    label present
%
%   clustersAllClasses - learnt weights V of every class stacked row-wise
%   clusterClassList   - class label of each row of clustersAllClasses
batchLen=1 %length(inTrain) %number of training patterns used in each learning batch
beta=0.1;
alpha=1;
clustersAllClasses=single([]);
clusterClassList=single([]);
%patchClassList(patchClassList>1)=1; %cluster all non-background patches together

classes=unique(patchClassList);
for c=classes
  %extract class-specific patches
  patchesInClass=patches(patchClassList==c,:);
  [numPatches,patchLen]=size(patchesInClass);
  numNodes=n(c+1);
  [W,V]=weight_initialisation_random(numNodes,patchLen);
  cycs=max(5*numPatches,fix(100*numNodes/sqrt(batchLen))) %number of training cycles
  for cyc=1:cycs
    %choose a batch of input stimuli to train on: draw only batchLen distinct
    %indices rather than permuting every patch index on each cycle
    batch=randperm(numPatches,batchLen);
    %calculate responses to training data
    [Y,E]=dim_activation(W,patchesInClass(batch,:)',V);
    %update weights
    [W,V]=dim_learn(W,V,Y,E,beta,alpha);
  end
  clustersAllClasses=[clustersAllClasses;V];
  clusterClassList=[clusterClassList,c.*ones(1,numNodes,'single')];
  disp(['extracted ',int2str(size(V,1)),' clusters for class ',int2str(c)])
end
disp(['extracted ',int2str(size(clustersAllClasses,1)),' clusters for all classes'])
function [clustersAllClasses,clusterClassList]=cluster_patches_agglom(patches,patchClassList,numClusterMembersReqd,similarityThres)
% Builds a codebook by agglomerative clustering, one class at a time.
%
% The patches of each class are clustered separately and only clusters with
% more than numClusterMembersReqd members are retained. Retained cluster
% centres are stacked row-wise in clustersAllClasses, with the class label of
% each row in clusterClassList.
clusterClassList=single([]);
clustersAllClasses=single([]);
for c=unique(patchClassList)
  %cluster the patches belonging to this class, to find cluster centres
  inClass=(patchClassList==c);
  [centres,memberCounts]=agglomerative_clustering(patches(inClass,:),similarityThres);
  numClusters=size(centres,1)
  %retain the well-populated clusters - i.e. the most frequently encountered patches
  frequent=memberCounts>numClusterMembersReqd;
  numKept=nnz(frequent);
  clustersAllClasses=[clustersAllClasses;centres(frequent,:)];
  clusterClassList=[clusterClassList,c.*ones(1,numKept,'single')];
  disp(['extracted ',int2str(numKept),' clusters for class ',int2str(c)])
end
disp(['extracted ',int2str(size(clustersAllClasses,1)),' clusters for all classes'])
function [clusters,len]=agglomerative_clustering(data,similarityThres)
%Performs hierarchical agglomerative clustering by calling the built-in matlab
%function "clusterdata". However, this method of clustering is very memory
%hungry due to the calculation of the similarity of every pair of
%samples. Hence, if there are lots of samples, clustering is performed multiple
%times on non-overlapping subsets of samples and the results are merged - a kludge!
%
% data            - one sample per row
% similarityThres - similarity threshold; clusterdata is given the distance
%                   cutoff 1-similarityThres
%
% clusters - one cluster centre per row (mean of the members found in the
%            subset that created the cluster)
% len      - number of samples attributed to each cluster
maxNumPerClustering=40000; %machine dependent!
mergeSimilarity=0.97;      %new clusters at least this similar to an existing one are merged into it

%halve the subset size until it fits within the limit
numSamplesTotal=size(data,1);
numSamplesPerClustering=numSamplesTotal;
repeats=0;
while numSamplesPerClustering>maxNumPerClustering
  numSamplesPerClustering=ceil(numSamplesPerClustering/2);
  repeats=repeats+1;
end
repeats=2^repeats;
if repeats>1
  disp(['WARNING: due to memory limitations APPROXIMATE clustering being performed using ',...
        int2str(repeats),' iterations']);
end

clusters=[];
len=[];
for it=1:repeats
  indToCluster=1+(it-1)*numSamplesPerClustering:min(it*numSamplesPerClustering,numSamplesTotal);
  if isempty(indToCluster)
    %subsets are consecutive, so once one is empty no samples remain
    break
  end
  patchesToCluster=data(indToCluster,:);
  clusterIndex = clusterdata(patchesToCluster,'cutoff',1-similarityThres,'criterion','distance','linkage','complete','distance',@distance_measure);%note clustering is based on distance, not similarity, so need to use 1-similarity threshold

  %place all patches into their clusters
  numClusters=max(clusterIndex)
  newClusters=single([]);newLen=[];
  for i=1:numClusters
    indClusterMembers=find(clusterIndex==i);
    newClusters(i,:)=mean(patchesToCluster(indClusterMembers,:),1); %each cluster is mean of all members
    newLen(i)=length(indClusterMembers);
  end

  %remove clusters that are similar to those already found on previous iterations
  if it>1
    simExisting=zeros(1,numClusters);
    indExisting=zeros(1,numClusters); %reset so no entries survive from an earlier subset
    for i=1:numClusters
      %find similarity of closest patch in existing clusters
      [simExisting(i),indExisting(i)]=max(distance_measure(newClusters(i,:),clusters,1));
    end
    %credit the members of each merged cluster to the existing cluster it
    %matches. Done one at a time: a single indexed assignment with repeated
    %indices keeps only the last write, which would lose counts whenever
    %several new clusters match the same existing cluster.
    for i=find(simExisting>=mergeSimilarity)
      len(indExisting(i))=len(indExisting(i))+newLen(i);
    end
    keep=find(simExisting<mergeSimilarity);
    newClusters=newClusters(keep,:);
    newLen=newLen(keep);
  end

  %add remaining clusters to those found on previous iterations
  clusters=[clusters;newClusters];
  len=[len,newLen];
end
function locations=extract_locations(I,imageClassList,objLocation,objRegion,patches,patchClassList,sigmaLGN)
%Defines the locations where votes will be cast relative to the patch/keypoint.
%
%Every training image of each non-background class is matched against the
%codebook using DIM. The thresholded response peaks of each voting codebook
%entry are accumulated in a map that is aligned so that the object location of
%every image falls on the centre of the map.
%
%Inputs (as used below):
% I              - cell array of training images
% imageClassList - class label of each image; label 0 is treated as background
% objLocation    - one row per image; column 1 is paired with the first image
%                  dimension and column 2 with the second
% objRegion      - cell array of per-image masks multiplied into the responses
% patches        - codebook, one patch per row (patch length is taken to be a
%                  perfect square)
% patchClassList - class label of each codebook entry
% sigmaLGN       - passed on to define_dictionary and imnorm
%Output:
% locations      - cell array; locations{c,k} is the accumulated vote map for
%                  class label c and the k-th non-background codebook entry
%NOTE(review): locations is indexed directly by the class label c, so labels
%are presumably positive integers - confirm against the callers.
[numTemplates,patchLen]=size(patches)
patchLen=sqrt(patchLen);
patchHalfLen=(patchLen-1)/2;
%NOTE(review): numTemplates and patchHalfLen are not used in the visible part of this function
ONOFF=1;
%CONVERT PATCH CODEBOOK TO DIM WEIGHTS
[w,v]=define_dictionary(patches,ONOFF,sigmaLGN);
%largest response seen over all images and codebook entries (echoed per class)
ymax=0;
classes=unique(imageClassList);
if classes(1)==0, classes=classes(2:end); end %class 0 is background - no voting required
for c=classes
indImagesInClass=find(imageClassList==c); %all images in class
%NOTE(review): despite the name this selects every non-background codebook
%entry, not only the entries of class c
indPatchesInClass=find(patchClassList>0); %all non-background patches
%initialise voting weights for each codebook entry
%A,B: largest image height and width found in this class
A=0;B=0;
for i=1:length(indImagesInClass)
[a,b]=size(I{indImagesInClass(i)});
A=max(A,a);B=max(B,b);
end
A=2*A;B=2*B; %make large enough that votes will not fall off edge
for j=1:length(indPatchesInClass)
locations{c,j}=zeros(A,B,'single');
end
for i=1:length(indImagesInClass)
disp(['extracting locations: class ',int2str(c),' of ',int2str(length(classes)),' / image ',int2str(i),' of ',int2str(length(indImagesInClass))])
[a,b]=size(I{indImagesInClass(i)});
%rows/columns of the vote map covered by this image, offset so that the
%object location lands on the map centre (A/2,B/2)
rangeA=round(A/2-objLocation(indImagesInClass(i),1))+[1:a];
rangeB=round(B/2-objLocation(indImagesInClass(i),2))+[1:b];
%MATCH CODEBOOK PATCHES TO IMAGE using DIM
if ONOFF
%two input channels are produced by imnorm and passed to DIM together
[~,~,X{1},X{2},trueRange]=imnorm(I{indImagesInClass(i)},sigmaLGN,[],1);
y=dim_activation_conv_recurrent(w,X,v,[],[],trueRange);
else
y=dim_activation_conv_recurrent(w,{I{indImagesInClass(i)}},v,[]);
end
%y is used below as a cell array with one response map per codebook entry
%(presumably the same size as the image - it is added into an a-by-b window)
ymax=max(ymax,max(max(max(cat(3,y{:}),[],3))));
%select only those matches that come from voting elements (non-background
%patches) and exceed a threshold
threshold=1e-3;
%k counts voting entries; j is the index of the entry in the full codebook
k=0;
for j=indPatchesInClass
k=k+1;
%keep the response only at its regional maxima
y{j}=imregionalmax(y{j}).*y{j};
y{j}(y{j}<threshold)=0;
y{j}=y{j}.*objRegion{indImagesInClass(i)}; %only allow votes from patches that touch the object
%accumulate the surviving responses into the aligned vote map
locations{c,k}(rangeA,rangeB)=locations{c,k}(rangeA,rangeB)+y{j};
end
end
for j=1:length(indPatchesInClass)
%crop to make matching faster, assumes vote locations are not on the periphery
%cropA=round(A/4);
%cropB=round(B/4);
%locations{c,j}=locations{c,j}(cropA+1:A-cropA,cropB+1:B-cropB);
%count the number of votes corresponding to each cluster
sumLocations(j)=sum(locations{c,j}(:));
end
%report the per-entry vote totals and the running maximum response
disp_stats(sumLocations,'sumLocations: ')
ymax
end