You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
120 lines
2.6 KiB
Matlab
120 lines
2.6 KiB
Matlab
function plot_network(x,W,y,labely,labelx,a1,a2,rotate)
% PLOT_NETWORK  Plot a schematic of a two-layer neural network.
%
%   PLOT_NETWORK(X,W,Y) draws M = length(X) input nodes and N = length(Y)
%   output nodes on two parallel rows, joined by one line per weight W(i,j)
%   (input i -> output j).  The shorter row is centred against the longer.
%
%   PLOT_NETWORK(X,W,Y,LABELY,LABELX,A1,A2,ROTATE) sets optional flags:
%     LABELY - if true, print each output value Y(j) above its node
%              (default 0).
%     LABELX - if true, print each input value X(i) as text instead of
%              drawing the input node (default 1).
%     A1     - a red arrow is drawn at input i when A1(i) > 0; only used
%              when LABELX is false.  Default zeros(M,1).
%     A2     - a red arrow is drawn at output j when A2(j) > 0.
%              Default zeros(N,1).
%     ROTATE - if true, inputs are placed at x = 0 and outputs at x = 1
%              (nodes running downwards); otherwise inputs are at y = 0 and
%              outputs at y = 1 (default 0).
%
%   Weights with W(i,j) > 0.01 are drawn as solid blue lines of width
%   4*W(i,j); weights with W(i,j) < -0.01 as dashed cyan lines of width
%   -4*W(i,j); smaller weights are not drawn.  Activations below 0.01 are
%   treated as 0.  Each drawn node is a white circle, filled with a green
%   circle scaled by its activation when that activation is positive.
%
%   Side effects: draws into the current axes, leaves "hold on" set for
%   those axes, and turns the axis display off.

m=length(x);
n=length(y);

% Optional arguments.  a1 and a2 get independent defaults so that a caller
% who passes a1 (six arguments) keeps it.
if nargin<4, labely=0; end
if nargin<5, labelx=1; end
if nargin<6, a1=zeros(m,1); end
if nargin<7, a2=zeros(n,1); end
if nargin<8, rotate=0; end

% Centre the shorter layer against the longer one.
nmax=max(m,n);
moffset=0.5*(nmax-m);
noffset=0.5*(nmax-n);

% Size markers and fonts from the axes size in points: the long side holds
% nmax+1 node slots, and a node may use at most half of the short side.
units=get(gca,'Units');
set(gca,'Units','points');
loc=get(gca,'Position');
set(gca,'Units',units);
if rotate
    ptspernode=min(loc(4)/(nmax+1),loc(3)/2);
else
    ptspernode=min(loc(3)/(nmax+1),loc(4)/2);
end
% fix() can yield 0 on very small axes, but MarkerSize must be positive.
markersize=max(1,fix(ptspernode*0.8));
fntsize=max(11,fix(ptspernode/4));

% Prepare the axes once -- clearing them unless hold is already on, exactly
% as the first plot() call would -- then keep every later object on them.
newplot;
hold on

%PLOT WEIGHTS
for i=1:m
    for j=1:n
        ipt=[i+moffset,j+noffset];
        jpt=[0,1];
        if rotate, [ipt,jpt]=swap(ipt,jpt); end
        w=W(i,j);
        if w>0.01
            plot(ipt,jpt,'b-','LineWidth',4*w);
        elseif w<-0.01
            plot(ipt,jpt,'c--','LineWidth',-4*w);
        end
    end
end

%SHOW ACTIVATIONS
%don't label tiny activation values
x(x<0.01)=0;
y(y<0.01)=0;

% Input layer (y = 0, or x = 0 when rotated).
for i=1:m
    ipt=i+moffset;
    jpt=0;
    if rotate, [ipt,jpt]=swap(ipt,jpt); end
    if labelx
        %label inputs
        text(ipt,jpt-0.10,num2str(x(i),2),'HorizontalAlignment','center', ...
            'VerticalAlignment','top','FontSize',fntsize);
    else
        %draw inputs
        plot(ipt,jpt,'ko','MarkerSize',markersize,'MarkerFaceColor','w');
        if x(i)>0
            plot(ipt,jpt,'go','MarkerSize',x(i)*markersize,'MarkerFaceColor','g');
        end
        if a1(i)>0
            plot_attn_arrow(ipt,jpt,m,n,rotate);
        end
    end
end

% Output layer (y = 1, or x = 1 when rotated).
for j=1:n
    ipt=j+noffset;
    jpt=1;
    if rotate, [ipt,jpt]=swap(ipt,jpt); end
    plot(ipt,jpt,'ko','MarkerSize',markersize,'MarkerFaceColor','w');
    if y(j)>0
        plot(ipt,jpt,'go','MarkerSize',y(j)*markersize,'MarkerFaceColor','g');
    end
    if labely
        text(ipt,jpt+markersize/80,num2str(y(j),2),'HorizontalAlignment', ...
            'center','VerticalAlignment','bottom','FontSize',fntsize);
    end
    if a2(j)>0
        plot_attn_arrow(ipt,jpt,m,n,rotate);
    end
end

% The rotated layout keeps extra room above the first row (up to y = 0),
% which the attention arrows drawn there extend into.
if rotate
    axis([-0.5,1.5,-nmax-0.5,0]);
else
    axis([0.5,nmax+0.5,-0.5,1.5]);
end
axis off

function [xr,yr]=swap(a,b)
% SWAP  Map coordinates (A,B) to (B,-A) for the rotated network layout.
%   [XR,YR] = SWAP(A,B) returns XR = B and YR = -A, i.e. it rotates the
%   point (A,B) clockwise by 90 degrees.  A and B may be arrays; the
%   negation of A is element-wise.
xr=b;
yr=-a;

function plot_attn_arrow(ipt,jpt,m,n,rotate)
% PLOT_ATTN_ARROW  Draw a small red arrow pointing at a node.
%   PLOT_ATTN_ARROW(IPT,JPT,M,N,ROTATE) draws three red line segments
%   (LineWidth 2) next to the node at (IPT,JPT).  A diagonal shaft starts
%   at a tip point offset from the node, and two short strokes also start
%   at that tip to form the arrowhead.  ROTATE selects the geometry used
%   for the rotated layout.  M and N are accepted but currently unused.
%   The segments are added with plot(), so the caller must have "hold on"
%   set; otherwise each segment would replace the previous one.

% Fixed aspect correction applied to one axis; a variant based on
% max([m,n]) existed but is disabled.
% NOTE(review): 2/4 presumably suits a 4-node-wide layout -- confirm.
aspect=2./4;
lineargs={'r-','LineWidth',2};
if rotate
    % Tip above and to the right of the node; shaft runs further up-right.
    tipx=ipt+0.25*aspect;
    tipy=jpt+0.25;
    headlen=0.3;
    plot([tipx,tipx+0.5*aspect],[tipy,tipy+0.5],lineargs{:});          % shaft
    plot([tipx-0.05,tipx+headlen*aspect],[tipy,tipy],lineargs{:});     % horizontal stroke
    plot([tipx,tipx],[tipy-0.05,tipy+headlen],lineargs{:});            % vertical stroke
else
    % Tip above and to the left of the node; shaft runs further up-left.
    tipx=ipt-0.2;
    tipy=jpt+0.2*aspect;
    plot([tipx,tipx-0.5],[tipy,tipy+0.5*aspect],lineargs{:});          % shaft
    plot([tipx,tipx],[tipy,tipy+0.3*aspect],lineargs{:});              % vertical stroke
    plot([tipx,tipx-0.3],[tipy,tipy],lineargs{:});                     % horizontal stroke
end