You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

324 lines
8.5 KiB
Matlab

function dim_squares()
%DIM_SQUARES Train a network on images made of overlapping squares and plot
%the weights it learns.
%
%   A fixed training set of 'patterns' p-by-p images is generated with
%   squares_pattern_randprob, each built from s-by-s square components.
%   The non-negative weight matrix W (n nodes x m inputs) is then updated
%   for 'cycs' presentations of randomly chosen training images. A progress
%   marker is printed every 1000 cycles and the weights are plotted with
%   squares_plot every 'show' cycles. At the end the per-node weight totals
%   are displayed together with their max/min and the max/min single weight.

%set network parameters
beta=0.05;       %learning rate
iterations=50;   %number of iterations to calculate y (for each input)
epsilon=1e-10;   %small constant that keeps the divisions below finite
%define task
p=6;             %length of one side of input image
s=3;             %size of square image components
n=48;            %number of nodes
m=p*p;           %number of inputs
cycs=20000;      %number of training cycles
show=1000;       %how often to plot receptive field data
patterns=1000;   %number of training patterns in training set
numsquares=(p-s+1).^2;
probs=0.1*ones(1,numsquares);
mincontrast=1;
%probs=0.02+0.18*rand(1,numsquares);
%mincontrast=0.1;

%generate a fixed pattern set: one m-element column per training image
%(preallocated instead of growing the array inside the loop)
data=zeros(m,patterns);
for k=1:patterns
  data(:,k)=squares_pattern_randprob(p,s,probs,mincontrast);
end

%define initial weights: Gaussian distributed with mean 1/16 and standard
%deviation 1/64, with negative values clipped to zero
W=(1/16)+(1/64).*randn(n,m);
W(W<0)=0;

%learn receptive fields
for k=1:cycs
  if rem(k,1000)==0, fprintf(1,'.%i.',k); end
  patternNum=fix(rand*patterns)+1; %random order
  x=data(:,patternNum);
  %weights into each node normalised by that node's maximum weight
  %(max along dim 2 stays per-node even in the degenerate m==1 case)
  What=W./(epsilon+(max(W,[],2)*ones(1,m)));
  %iterate to calculate node activations; e is the element-wise ratio of
  %the input to What'*y
  y=zeros(n,1);
  for i=1:iterations
    e=x./(epsilon+(What'*y));
    y=(epsilon+y).*(W*e);
  end
  %update weights multiplicatively and keep them non-negative
  W=W.*( 1 + beta.*( y*(e'-1) ));
  W(W<0)=0;
  if rem(k,show)==0
    %show weights
    squares_plot(s,W);
  end
end
%display per-node weight totals (the name s is reused here), then the
%largest/smallest total and the largest/smallest individual weight
s=sum(W'), disp(num2str([max(s),min(s),max(max(W)),min(min(W))]))
disp('');
function [x,patterns,input_set_components]=squares_pattern_randprob(m,s,prob,mincontrast)
%function [x,patterns,input_set_components]=squares_pattern_randprob(m,s,prob,mincontrast)
%
%create a mxm pixel image in which overlapping sxs squares are randomly
%active with probability 'prob', where prob is a vector defining the
%independent probability for each separate component.
%The contrast of each component is assigned randomly (between mincontrast and
%1) in each generated pattern. One component, chosen with probability
%proportional to prob, is always switched on. Where active components
%overlap, the one with the largest randomly assigned depth is drawn on top.
%
%Inputs:
%  m           - side length of the image in pixels
%  s           - side length of each square component
%  prob        - vector of (m-s+1)^2 per-component probabilities. Component k
%                has its top-left pixel at row rem(k-1,m-s+1)+1 and column
%                fix((k-1)/(m-s+1))+1
%  mincontrast - lower bound of each component's uniformly drawn contrast
%Outputs:
%  x                    - the image as an (m*m)x1 column vector (x(:) order)
%  patterns             - indices of active components visible at >=1 pixel
%  input_set_components - 1x(m-s+1)^2 vector, 1 for each index in patterns
nsquares=(m-s+1);
if ~any(prob>0)
  %no component could ever be switched on, so the blank-image retry below
  %would recurse forever
  error('squares_pattern_randprob:noProbability', ...
    'At least one component probability must be positive.');
end
%choose one square to be present, so that each pattern will contain at least
%one square, with the choice weighted by each component's probability.
%Scaling rand by the last cumulative value (rather than a separately computed
%sum) guarantees that some component satisfies the comparison.
cumprob=cumsum(prob);
included=find(rand*cumprob(end)<cumprob,1);
depthorder=randperm(nsquares^2);%randomly assign a depth to each possible
%component to decide which contrast goes on top
%randomly assign a contrast between mincontrast and 1 to each possible component
contrast=mincontrast+(1-mincontrast).*rand(1,nsquares^2);
%add patterns to input
npattern=0;
patterns=[];
x=zeros(m,m);
depth=zeros(m,m); %depth value of the top-most component at each pixel (0 = none)
for c=1:nsquares
  for r=1:nsquares
    npattern=npattern+1;
    if (rand<prob(npattern) | npattern==included) & contrast(npattern)>0
      %decide at which pixels current component is in front
      depth(r:r+s-1,c:c+s-1)=max(depth(r:r+s-1,c:c+s-1), ...
        depthorder(npattern));
      %fill in those pixels with contrast of current component
      for j=r:r+s-1
        for i=c:c+s-1
          if depth(j,i)==depthorder(npattern)
            x(j,i)=contrast(npattern);
          end
        end
      end
      %record fact that this component has been selected
      patterns=[patterns,npattern];
    end
  end
end
%remove any patterns from the list of included patterns that are entirely
%occluded by other patterns
npattern=0;
occludedpatterns=[];
for c=1:nsquares
  for r=1:nsquares
    npattern=npattern+1;
    if ismember(npattern,patterns)
      %see if this pattern is on top at any pixel
      if ~ismember(depthorder(npattern),depth(r:r+s-1,c:c+s-1))
        occludedpatterns=[occludedpatterns,npattern];
      end
    end
  end
end
patterns=setdiff(patterns,occludedpatterns);
x=x(:);
if sum(x)==0
  %nothing was drawn (e.g. mincontrast<0 allows non-positive contrasts, which
  %are skipped above), so draw a fresh pattern with the same arguments
  disp('missing');
  [x,patterns,input_set_components]=squares_pattern_randprob(m,s,prob,mincontrast);
  return
end
input_set_components=zeros(1,nsquares^2);
input_set_components(patterns)=1;
function [nrepanycomplete]=squares_plot(sqsize,weights)
%function [nrepanycomplete]=squares_plot(sqsize,weights)
%Plot every node's weights and report which square components are represented.
% sqsize  - side length(s) of the square components (scalar or vector)
% weights - n-by-m matrix with one row per node. Each row is reshaped to a
%           sqrt(m)-by-sqrt(m) image, so m is assumed to be a perfect square.
%           If omitted, the matrix is loaded from 'weights_1basal.dat'.
% nrepanycomplete - number of components preferred by at least one node
%Each node's weights are drawn as a Hinton image in its own subplot. A node
%"prefers" a component when its weights inside the square dominate those
%outside it (see the three-part test below). The subplot is labelled with the
%preferred component index(es), in parentheses when already claimed earlier.
if nargin<2
  weights=load('weights_1basal.dat');
end
[n,m]=size(weights);
p=sqrt(m);
scale=max(max(weights))*0.85; %common scale for all subplots
plot_per_row=min(n,8);
num_rows=ceil(n/plot_per_row);
clf
sqsize=sort(sqsize);
nsqsizes=length(sqsize);
%total number of possible components over all square sizes
totalsquares=sum((p-sqsize+1).^2);
representedcomplete=zeros(1,totalsquares); %number of nodes preferring each component
for j=1:n
  w=reshape(weights(j,:),p,p);
  subplot(num_rows,plot_per_row,j),hinton_plot(w,scale,3,1,1);
  %determine degree of match between weights and all possible input patterns
  npattern=0;
  lpref=[];
  for k=1:nsqsizes
    s=sqsize(k);
    nsquares=(p-s+1);
    for c=1:nsquares
      for r=1:nsquares
        npattern=npattern+1;
        wOut=w; wOut(r:r+s-1,c:c+s-1)=0;
        wIn=w(r:r+s-1,c:c+s-1);
        if sum(sum(wIn))>3*sum(sum(max(0,wOut))) & ...
            min(min(wIn))>max(max(wOut)) & ...
            min(min(wIn))>mean(mean(max(0,w)))
          lpref=[lpref,npattern];
        end
      end
    end
  end
  if length(lpref)>1
    lpref
  end
  representedcomplete(lpref)=representedcomplete(lpref)+1;
  %parentheses mark a component already claimed by an earlier node (with
  %several matches, only when every one of them was already claimed)
  if representedcomplete(lpref)>1
    text(1,0.25,['(',int2str(lpref),')']);
  else
    text(1,0.25,int2str(lpref));
  end
end
nrepanycomplete=nnz(representedcomplete>0);
disp(['network represents ', int2str(nrepanycomplete), ' patterns ']);
norep=find(representedcomplete==0);
%report any component that no node prefers (the original test
%length(norep>1) had a misplaced parenthesis; this states its effective intent)
if ~isempty(norep)
  disp(['FAILED to represent pattern = ',int2str(norep)]);
end
function hinton_plot(W, scale, colour, type, equal)
%function hinton_plot(W, scale, colour, type, equal)
%Draw matrix W in the current axes as a Hinton-style diagram.
% W      - matrix to draw; negative values are clipped to 0 first
% scale  - value of W drawn as a full-size box (types 0/3) or at full
%          intensity (types 1/2); larger values saturate in types 1/2
% colour - 1, 2 or 3, selecting red, green or blue
% type 0 = variable size boxes (size relates to strength)
% type 1 = image (color intensity relates to strength)
% type 2 = equal size squares (color intensity relates to strength)
%          (any type other than 0, 1 or 3 is drawn as type 2)
% type 3 = same as type 0 but with outer border showing maximum size of box
% equal  - 1 to set equal axis scaling; 0 to keep the axes as they are and
%          instead rescale the boxes to compensate for their aspect ratio
W(W<0)=0;
if (type==1)
  %draw as an image: strength indicated by pixel darkness
  %W is true data value (greater than 0) and is scaled to be between 0 and 255
  imagesc(uint8(round((W./scale)*255)),[0,255]);
  hinton_colormap(colour)
  axis on
  if(equal==1), axis equal, axis tight, end
elseif (type==0 | type==3)
  %draw as squares: strength indicated by size of square
  colstr=['r','g','b'];
  [aspectXX,aspectYY]=hinton_aspect(W,equal);
  nrows=size(W,1);
  for i=1:size(W,2)
    for j=1:nrows
      %box half-widths proportional to the weight (W==scale fills the cell)
      box_widthX=aspectXX*0.5*W(j,i)/(scale);
      box_widthY=aspectYY*0.5*W(j,i)/(scale);
      xs=[i-box_widthX,i+box_widthX,i+box_widthX,i-box_widthX];
      ys=nrows+1-[j-box_widthY,j-box_widthY,j+box_widthY,j+box_widthY];
      h=fill(xs,ys,colstr(colour));
      hold on
      if (isnan(W(j,i)))
        %NaN entries give no box, so mark them with a cross
        plot(i,nrows+1-j,'kx','MarkerSize',20);
      end
      if (type==3)
        set(h, 'EdgeColor','w');
        %unfilled outline of the maximum (full-cell) box size
        box_widthX=aspectXX*0.5*1;
        box_widthY=aspectYY*0.5*1;
        xs=[i-box_widthX,i+box_widthX,i+box_widthX,i-box_widthX];
        ys=nrows+1-[j-box_widthY,j-box_widthY,j+box_widthY,j+box_widthY];
        h=fill(xs,ys,'w','FaceAlpha',0);
      end
    end
  end
  if(equal==1), axis equal, end
  axis([0.5,size(W,2)+0.5,+0.5,size(W,1)+0.5])
else
  %draw as equal sized squares: strength indicated by darkness of square
  [aspectXX,aspectYY]=hinton_aspect(W,equal);
  box_widthX=aspectXX*0.33;
  box_widthY=aspectYY*0.33;
  nrows=size(W,1);
  for i=1:size(W,2)
    for j=1:nrows
      xs=[i-box_widthX,i+box_widthX,i+box_widthX,i-box_widthX];
      ys=nrows+1-[j-box_widthY,j-box_widthY,j+box_widthY,j+box_widthY];
      fill(xs,ys,ones(1,4).*round((W(j,i)./scale)*255),'FaceColor','flat');
      hold on
    end
  end
  hinton_colormap(colour)
  caxis([0,255])%if we remove this then each subplot is scaled independently
  if(equal==1), axis equal, end
  axis([0.5,size(W,2)+0.5,+0.5,size(W,1)+0.5])
end
set(gca,'YTick',[])
set(gca,'XTick',[])
drawnow

function [aspectXX,aspectYY]=hinton_aspect(W,equal)
%Return x/y scale factors for the boxes hinton_plot draws over W.
%When equal==0 the current axes are probed (two temporary 'bx' markers plus
%'axis equal') and the factors compensate for the resulting unequal extents;
%the larger factor is 1. Otherwise both factors are 1.
if(equal==0)
  plot(size(W,2)+0.5,size(W,1)-0.5,'bx');
  hold on
  plot(0.5,-0.5,'bx');
  axis equal
  a=axis;
  aspectX=size(W,2)/abs(a(2)-a(1));
  aspectY=size(W,1)/abs(a(4)-a(3));
  aspectXX=aspectX./max(aspectX,aspectY);
  aspectYY=aspectY./max(aspectX,aspectY);
  hold off
else
  aspectXX=1;
  aspectYY=1;
end

function hinton_colormap(colour)
%Install an inverted grey colormap (lowest value = white) with channel
%'colour' held at full intensity, so values appear as shades of that colour.
colormap(gray)
map=flipud(colormap);
%tint every row; the previous fixed map(1:64,...) only covered the whole map
%when the colormap had 64 entries (newer MATLAB defaults to 256)
map(:,colour)=1;
colormap(map)