You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

162 lines
5.5 KiB
Matlab

function plot_network(x,r,y,partition,values,partitionLabels,proportionalLength)
%PLOT_NETWORK plot the activities of a multi-stage network in the current figure.
%   PLOT_NETWORK(X,R,Y,PARTITION,VALUES) uses a 3-row layout with one column
%   of panels per stage s=1..length(Y):
%     top row    - input activities X{s}(:,1), black bars, one panel per partition
%     middle row - prediction neuron responses Y{s}, red bars, one panel per stage
%     bottom row - reconstruction neuron responses R{s}(:,1), green bars, one
%                  panel per partition
%   PARTITION{s,p} holds the row indices of X{s}/R{s} that belong to partition
%   p of stage s; the non-empty entries of each row of PARTITION are assumed
%   to come first (partitions 1..numPartitions are plotted). VALUES{s,p} is
%   passed to RESCALE_AXES and PLOT_DECODE to label the x-axis and mark the
%   decoded value of each partition panel.
%   PLOT_NETWORK(...,PARTITIONLABELS) also writes PARTITIONLABELS{s,p} in the
%   top-left corner of each partition panel (skipped if omitted or empty).
%   PLOT_NETWORK(...,PARTITIONLABELS,PROPORTIONALLENGTH) with a true
%   PROPORTIONALLENGTH makes each partition panel's width proportional to the
%   number of elements in that partition (default 0: equal widths).
%   Axes are created with MAXSUBPLOT (not shown here). Finally the figure's
%   colormap is set to gray and its paper size/position set for printing.
numStages=length(y);
%optional arguments: test nargin rather than exist(), which would also match
%files/functions on the path called 'partitionLabels'
if nargin<6 || isempty(partitionLabels)
    partitionLabels={};
end
%decide whether to plot all partitions (in a stage) the same size, or have the size
%proportional to the number of inputs in that partition
if nargin<7 || isempty(proportionalLength)
    proportionalLength=0;
end

%decide on y-axis scales: all prediction panels share one scale, so only the
%first one is labelled. Pooling the values (rather than max(ymax,max(y{s})))
%keeps ymax a scalar even when some y{s} is empty.
ymax=0;
for s=1:numStages
    ymax=max([ymax;y{s}(:)]);
end
if ymax>=1
    ystep=0.5;
elseif ymax>=0.4
    ystep=0.2;
else
    ystep=0.1;
end
ytop=max(0.1,1.05.*ymax);
xtop=1.08;          %upper y-limit of the input panels
rtop=1.08;          %upper y-limit of the reconstruction panels
rc=[0,0.65,0];      %color for reconstruction neurons
spacing=0.1;        %spacing between axes
xshift=0.01./numStages;
endgap=0.5;         %distance between edge of histogram and edge of axes

%plot neural activities in each stage
for s=1:numStages
    %number of (non-empty) partitions in this stage
    numPartitions=nnz(~cellfun(@isempty,partition(s,:)));
    %plot prediction neuron responses (middle row); max(...,1) avoids a
    %division by zero for a stage without partitions
    pos=maxsubplot(3,numStages,numStages+s,spacing/max(numPartitions,1));
    set(gca,'Position',pos+[xshift,+0.05,0,-0.09]); %squeeze vertically, to make room for labels on axes above and below
    bar(y{s},1,'r','EdgeColor','r');
    axis([0.5,length(y{s})+0.5,0,ytop]);
    set(gca,'XTick',[],'YTick',0:ystep:1);
    if s>1, set(gca,'YTick',[]); end
    text(0.01,1,['y^{S',int2str(s),'}'],'Units','normalized','color','r','VerticalAlignment','top');
    for p=1:numPartitions
        idx=partition{s,p};
        label='';
        if ~isempty(partitionLabels), label=partitionLabels{s,p}; end
        hideYTicks=(p>1 || s>1); %only the very first panel of a row keeps its y ticks
        %plot inputs (top row)
        pos=partition_axes_position(1,s,p,numStages,numPartitions,idx,length(x{s}),spacing,endgap,proportionalLength);
        set(gca,'Position',pos+[xshift,+0.03,0,-0.04]); %squeeze vertically, to make room for labels
        plot_partition_bars(x{s}(idx,1),values{s,p},xtop,endgap,{'k'},hideYTicks,label,'k');
        %plot reconstruction neuron responses (bottom row)
        pos=partition_axes_position(3,s,p,numStages,numPartitions,idx,length(x{s}),spacing,endgap,proportionalLength);
        set(gca,'Position',pos+[xshift,+0.04,0,-0.05]); %squeeze vertically, to make room for labels
        plot_partition_bars(r{s}(idx,1),values{s,p},rtop,endgap,{'FaceColor',rc,'EdgeColor',rc},hideYTicks,label,rc);
    end
end
cmap=colormap('gray'); %#ok<NASGU> %cmap=1-cmap;colormap(cmap);
set(gcf,'PaperSize',[9*numStages 7],'PaperPosition',[0.5 0.3 9*numStages-0.5 6.5],'PaperOrientation','Portrait');
drawnow;

function pos=partition_axes_position(row,s,p,numStages,numPartitions,idx,totalLength,spacing,endgap,proportionalLength)
%Create (via MAXSUBPLOT) the axes for partition p of stage s in panel row ROW
%(1 = inputs, 3 = reconstructions) and return the position it should take,
%before the caller's vertical squeeze. IDX are the partition's element
%indices and TOTALLENGTH the number of elements in the whole stage.
%With PROPORTIONALLENGTH the stage's full-width axes is shared among its
%partitions, each getting a width proportional to its number of elements plus
%a margin of ENDGAP on either side. The horizontal offset uses idx(1), so
%partitions are presumably contiguous index ranges listed in order.
if proportionalLength
    pos=maxsubplot(3,numStages,(row-1)*numStages+s,spacing/numPartitions);
    gap=spacing/4; %horizontal gap between neighbouring partition panels
    scalefac=(pos(3)-(numPartitions-1)*gap)./(numPartitions*(2*endgap)+totalLength);
    pos(1)=pos(1)+(p-1)*gap+(idx(1)-1+(p-1)*(2*endgap))*scalefac;
    pos(3)=(length(idx)+(2*endgap))*scalefac;
else
    pos=maxsubplot(3,numStages*numPartitions,(row-1)*numStages*numPartitions+(s-1)*numPartitions+p,spacing);
end

function plot_partition_bars(z,vals,ytop,endgap,barStyle,hideYTicks,label,labelColor)
%Draw the activities Z of one partition as a bar chart in the current axes
%(y-limits 0..YTOP), mark the decoded value with PLOT_DECODE and label the
%x-axis with ticks chosen by RESCALE_AXES from VALS. BARSTYLE is a cell array
%of extra arguments for BAR; LABEL, if non-empty, is written in the top-left
%corner in LABELCOLOR; HIDEYTICKS removes the y tick labels.
bar(1:length(z),z,1,barStyle{:});
axis([0.5-endgap,length(z)+0.5+endgap,0,ytop]);
[xticks,xtolabel]=rescale_axes(axis,vals);
plot_decode(z,vals);
set(gca,'XTick',xticks,'XTickLabel',num2str(xtolabel','%+d'));
if hideYTicks, set(gca,'YTick',[]); end
if ~isempty(label)
    text(0.01,1,label,'Units','normalized','color',labelColor,'VerticalAlignment','top');
end
function [xticks,xtolabel]=rescale_axes(axrange,xrange)
%Choose "round" integer x-tick values for a partition panel and convert them
%to axes coordinates.
%   XRANGE(k) is the value represented by the bar at position k. Values are
%   mapped linearly from [min(XRANGE),max(XRANGE)] onto positions
%   1..length(XRANGE), exactly as PLOT_DECODE places the decoded value, so
%   tick labels line up with the bars and the decoded-value marker.
%   The tick magnitude is taken from |XRANGE| at 1/7 of its length, rounded
%   down to a multiple of 10 (or to an integer when that would give 0), and
%   ticks are placed at -step, 0 and +step (some may fall outside the axis
%   limits, e.g. -step for an all-positive range). If even the integer step
%   is 0 only a tick at 0 is returned.
%   AXRANGE (the current axis limits) is only used when XRANGE holds a single
%   distinct value: one tick is then placed at the centre of the x-limits and
%   labelled with that value rounded to an integer.
%   Returns XTICKS (axes positions) and XTOLABEL (integer values for them).
n=length(xrange);
if n==0
    xticks=[];
    xtolabel=[];
    return
end
s=abs(xrange(ceil(n/7)));
step=10*floor(s/10);
if step==0
    step=floor(s); %small ranges: fall back to integer steps
end
if step>0
    xtolabel=[-step,0,step];
else
    xtolabel=0;
end
lo=min(xrange);
hi=max(xrange);
if hi>lo
    %bar k sits at position k: same linear map as used in plot_decode
    xticks=1+(xtolabel-lo).*(n-1)./(hi-lo);
else
    %degenerate range: a single tick at the centre of the x-limits
    xtolabel=round(xrange(1));
    xticks=mean(axrange(1:2));
end
function plot_decode(z,s,integrate,labelSigma)
%Mark the decoded mean of the activity profile Z on the current axes.
%   Nothing is drawn unless ISUNIMODAL(Z) is true (a mean is meaningless for a
%   multimodal profile). Otherwise DECODE(Z',S,INTEGRATE) supplies MU and a
%   second output that is square-rooted for the label (presumably a variance).
%   MU is mapped linearly from [min(S),max(S)] onto positions 1..length(S), a
%   vertical black line overlaid with a white dashed line is drawn there, and
%   MU is printed in bold at 97% of the upper y-limit. With a true LABELSIGMA
%   the square root of the second output is appended in brackets.
%   INTEGRATE defaults to 0. Leaves HOLD on for the current axes.
if nargin<3 || isempty(integrate)
    integrate=0;
end
if ~isunimodal(z) %only do so if distribution is unimodal (mean is meaningless otherwise)
    return
end
[mu,sigma2]=decode(z',s,integrate);
%position of the decoded value on the bar axis (bar k represents s(k))
muposn=1+((mu-min(s)).*(length(s)-1)./(max(s)-min(s)));
lineX=muposn.*[1,1];
hold on
plot(lineX,[0,100],'k-');
plot(lineX,[0,100],'w--');
ax=axis;
labelStr=num2str(mu,'%+3.1f\n');
if nargin>3 && labelSigma
    labelStr=[labelStr,' (',num2str(sqrt(sigma2),'%3.1f\n'),')'];
end
text('Position',[muposn,ax(4)*0.97],'String',labelStr,'HorizontalAlignment','center','VerticalAlignment','Bottom','FontWeight','bold','FontSize',11);
function val=isunimodal(z,relThresh,maxSpan)
%Return 1 if the activity profile Z looks unimodal, 0 otherwise.
%   Local maxima of Z (elements at least as large as both neighbours, with
%   zero padding beyond the ends) that exceed RELTHRESH*max(Z) are located.
%   Z counts as unimodal if there is exactly one such peak, or if the first
%   and last of them are less than MAXSPAN positions apart (e.g. a
%   flat-topped peak). Empty or all-zero Z returns 0.
%   Z may be a row or a column vector.
%   RELTHRESH defaults to 0.05 and MAXSPAN to 3 (the original hard-coded values).
if nargin<2 || isempty(relThresh), relThresh=0.05; end
if nargin<3 || isempty(maxSpan), maxSpan=3; end
if isempty(z)
    val=0;
    return
end
z=z(:); %accept row as well as column vectors
zleft=[0;z(1:end-1)];
zright=[z(2:end);0];
peaks=min(z-zleft,z-zright); %nonnegative only for values at least as big as both neighbours
z(peaks<0)=0;                %suppress everything that is not a local maximum
posn=find(z>relThresh.*max(z));
val=double(~isempty(posn) && (numel(posn)==1 || posn(end)-posn(1)<maxSpan));