You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

120 lines
2.6 KiB
Matlab

function plot_network(x,W,y,labely,labelx,a1,a2,rotate)
% PLOT_NETWORK  Plot a schematic of a two-layer neural network.
%   PLOT_NETWORK(X,W,Y) draws the m = length(X) inputs on the row at
%   height 0 and the n = length(Y) outputs on the row at height 1 of the
%   current axes. It connects input i to output j with a line for weight
%   W(i,j), so W must be at least m-by-n. Weights above 0.01 are solid blue
%   lines and weights below -0.01 are dashed cyan lines, with line width
%   4*|W(i,j)|. Weights with |W(i,j)| <= 0.01 are not drawn.
%
%   PLOT_NETWORK(X,W,Y,LABELY,LABELX,A1,A2,ROTATE) also accepts:
%     LABELY - nonzero: print each output value next to its node (default 0)
%     LABELX - nonzero: print the input values as text instead of drawing
%              input nodes (default 1)
%     A1     - input flags; a red arrow is drawn at input i when A1(i) > 0
%              (only when input nodes are drawn, i.e. LABELX is false).
%              Default zeros(m,1).
%     A2     - output flags; a red arrow is drawn at output j when
%              A2(j) > 0. Default zeros(n,1).
%     ROTATE - nonzero: lay the network out rotated by a quarter turn
%              (default 0)
%   Passing [] for A1 or A2 is the same as omitting it.
%
%   Each drawn node is a white circle overlaid with a green marker whose
%   size is the activation times the full node size. Activations below
%   0.01 (including negative ones) are treated as 0.
m=length(x);
n=length(y);
if nargin<4
    labely=0;
end
if nargin<5
    labelx=1;
end
% Default each flag vector independently. (Previously a1 was reset to
% zeros whenever a2 was omitted, silently discarding a supplied a1.)
if nargin<6 || isempty(a1)
    a1=zeros(m,1);
end
if nargin<7 || isempty(a2)
    a2=zeros(n,1);
end
if nargin<8
    rotate=0;
end
% Centre the shorter layer against the longer one.
moffset=max(0,0.5*(n-m));
noffset=max(0,0.5*(m-n));

% Size nodes and fonts from the axes extent in points, then restore units.
units=get(gca,'Units');
set(gca,'Units','points');
loc=get(gca,'Position');
set(gca,'Units',units);
nslots=max([m,n])+1;
if rotate
    ptspernode=min([loc(4)./nslots,loc(3)./2]);
else
    ptspernode=min([loc(3)./nslots,loc(4)./2]);
end
% MarkerSize must be positive, so never let it round down to 0.
markersize=max(1,fix(ptspernode*0.8));
fntsize=max(11,fix(ptspernode/4));

% PLOT WEIGHTS
for i=1:m
    for j=1:n
        ipt=[i+moffset,j+noffset];
        jpt=[0,1];
        if rotate, [ipt,jpt]=swap(ipt,jpt); end
        if W(i,j)>0.01
            plot(ipt,jpt,'b-','LineWidth',4*W(i,j));
        elseif W(i,j)<-0.01
            plot(ipt,jpt,'c--','LineWidth',-4*W(i,j));
        end
        % hold comes after the first plot call, so when hold was off a
        % weight line drawn at (1,1) replaces the axes' previous contents.
        hold on
    end
end
% Make sure hold is on even if the loop above never ran (m or n is 0).
hold on

% SHOW ACTIVATIONS
% Don't label or draw tiny activation values.
x(x<0.01)=0;
y(y<0.01)=0;
% Input layer (row 0 before rotation).
for i=1:m
    ipt=i+moffset;
    jpt=0;
    if rotate, [ipt,jpt]=swap(ipt,jpt); end
    if labelx
        % label inputs with their values instead of drawing nodes
        text(ipt,jpt-0.10,num2str(x(i),2),'HorizontalAlignment','center','VerticalAlignment','top','FontSize',fntsize)
    else
        % draw input node, green fill scaled by the activation
        plot(ipt,jpt,'ko','MarkerSize',markersize,'MarkerFaceColor','w');
        if x(i)>0.0
            plot(ipt,jpt,'go','MarkerSize',x(i)*markersize,'MarkerFaceColor','g');
        end
        if a1(i)>0.0
            plot_attn_arrow(ipt,jpt,m,n,rotate);
        end
    end
end
% Output layer (row 1 before rotation).
for j=1:n
    ipt=j+noffset;
    jpt=1;
    if rotate, [ipt,jpt]=swap(ipt,jpt); end
    plot(ipt,jpt,'ko','MarkerSize',markersize,'MarkerFaceColor','w');
    if y(j)>0.0
        plot(ipt,jpt,'go','MarkerSize',y(j)*markersize,'MarkerFaceColor','g');
    end
    if labely
        text(ipt,jpt+markersize/80,num2str(y(j),2),'HorizontalAlignment', ...
            'center','VerticalAlignment','bottom','FontSize',fntsize);
    end
    if a2(j)>0.0
        plot_attn_arrow(ipt,jpt,m,n,rotate);
    end
end
if rotate
    axis([-0.5,1.5,-max([m,n])-0.5,0]);
else
    axis([0.5,max([m,n])+0.5,-0.5,1.5]);
end
axis off
function [b,a]=swap(a,b)
% SWAP  Quarter-turn (clockwise) rotation of coordinates.
%   [P,Q] = SWAP(A,B) returns P = B and Q = -A elementwise, so the point
%   (A,B) maps to (B,-A). plot_network uses it to turn its default
%   horizontal layout into the rotated one.
a=uminus(a); % second output is the negated first input; first output is B untouched
function plot_attn_arrow(ipt,jpt,m,n,rotate)
% PLOT_ATTN_ARROW  Draw a small red arrow beside a node.
%   PLOT_ATTN_ARROW(IPT,JPT,M,N,ROTATE) draws three red, 2-pt line segments
%   on the current axes: a diagonal shaft and two short head strokes. The
%   tip sits slightly offset from the node at (IPT,JPT) and points toward
%   it. ROTATE selects the geometry used for the rotated layout.
%   M and N are accepted but unused. The aspect factor is fixed at 0.5;
%   a commented-out alternative in the original code hints at max([m,n]).
%   Each segment is a separate plot call, so hold should be on for all
%   three to remain visible.
equalarrowaspect=2./4;
if rotate
    % Rotated layout: shaft runs up and to the right from the tip (ia,ja).
    ia=ipt+0.25*equalarrowaspect;
    ja=jpt+0.25;
    l=0.3;
    segx={[ia,ia+0.5*equalarrowaspect],[ia-0.05,ia+l*equalarrowaspect],[ia,ia]};
    segy={[ja,ja+0.5],[ja,ja],[ja-0.05,ja+l]};
else
    % Default layout: shaft runs up and to the left from the tip (ia,ja).
    ia=ipt-0.2;
    ja=jpt+0.2*equalarrowaspect;
    segx={[ia,ia-0.5],[ia,ia],[ia,ia-0.3]};
    segy={[ja,ja+0.5*equalarrowaspect],[ja,ja+0.3*equalarrowaspect],[ja,ja]};
end
% Shaft first, then the two head strokes, in the original drawing order.
for k=1:numel(segx)
    plot(segx{k},segy{k},'r-','LineWidth',2);
end