You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

139 lines
4.6 KiB
Matlab

function [Y,E,R]=dim_activation_conv(w,X,v,Y,iterations,trueRange)
%DIM_ACTIVATION_CONV  Convolutional Divisive Input Modulation (DIM) responses.
%   [Y,E,R] = DIM_ACTIVATION_CONV(w,X,v,Y,iterations,trueRange) iterates the DIM
%   update equations, implemented with 2-D convolutions, and returns the
%   prediction-node activations Y after the final iteration. E and R are the error
%   and reconstruction computed during the final iteration, i.e. from the Y values
%   that existed before the final update of Y. Progress is printed to stdout as
%   'dim_conv: .1..2. ...'.
%
% w = a cell array of size {N}, where N is the number of distinct neuron types.
%     Each element w{i} is a 3-dimensional matrix. w{i}(:,:,j) is a convolution
%     mask specifying the synaptic weights for neuron type i's RF in input
%     channel j. Internally each w{i} is normalised to sum to one and rotated by
%     180 degrees (the caller's copy is not modified).
% X = a three dimensional matrix. X(a,b,j) specifies the bottom-up input to
%     the DIM neural network at location a,b in channel j.
% v = feedback synaptic weights, defined as for w, but strength normalised
%     differently. Optional: if omitted or empty, v{i} is w{i} divided by its
%     largest element (divisor floored at 1e-6), with negative values clipped to 0.
% Y = a three dimensional matrix. Y(a,b,i) specifies the prediction node
%     activation for type i neurons at location a,b. Optional initial value;
%     defaults to zeros(a,b,N,'single').
% iterations = the number of iterations performed by the DIM algorithm. Default
%     is 30. If less than one, Y is not updated and E and R are computed from
%     the initial Y.
% trueRange = optional cell array {rowIndices, colIndices}. Y (and E and R, when
%     requested) are cropped to these rows/columns at the end of processing.
%     Used to try to avoid edge effects by having input (X) larger than the
%     original image, and then cropping edges of outputs, so that outputs are
%     resized to be same size as the original image.
%
% Outputs:
% Y = prediction node activations (negative values clipped to zero).
% E = a three dimensional matrix. E(a,b,j) specifies the error in the
%     reconstruction of the input at location a,b in channel j (X./R, with R
%     floored at epsilon2).
% R = a three dimensional matrix. R(a,b,j) specifies the reconstruction of the
%     input at location a,b in channel j.
%
% NOTE(review): iszero is not defined in this file; presumably it returns true
% when every element of its argument is zero — confirm against its definition.
[a,b,nInputChannels]=size(X);
nMasks=length(w);
[~,~,nChannels]=size(w{1});
if nargin<3 || isempty(v)
  %set feedback weights equal to feedforward weights normalized by maximum value for each node
  for i=1:nMasks
    v{i}=max(0,w{i}./max(1e-6,max(max(max(w{i})))));
  end
end
if nargin<4 || isempty(Y)
  %initialise prediction neuron outputs to zero
  Y=zeros(a,b,nMasks,'single');
end
sumV=0;
for i=1:nMasks
  %masks should be an odd size, so that reconstruction isn't shifted with respect
  %to the original image
  %w{i}=pad_to_make_odd(w{i});
  %v{i}=pad_to_make_odd(v{i});
  %try to speed things up by removing symmetrical rows/columns of zeros from the
  %edges of the weight masks
  %w{i}=trimarray_symmetrically(w{i});
  %v{i}=trimarray_symmetrically(v{i});
  %normalise feedforward weights to sum to one for each node
  w{i}=w{i}./max(1e-6,sum(sum(sum(w{i}))));
  %rotate feedforward weights so that convolution can be used to apply the
  %filtering (otherwise the mask gets rotated every iteration!)
  w{i}=rot90(w{i},2);
  sumV=sumV+sum(sum(v{i}));
end
%set parameters
epsilon2=1e-2;
epsilon1=epsilon2./max(sumV(:));
if nargin<5 || isempty(iterations), iterations=30; end

%The weights do not change from here on, so work out once which mask/channel
%pairs are empty or all-zero: convolving with them adds nothing.
vActive=false(nMasks,nChannels);
wActive=false(nMasks,nChannels);
for i=1:nMasks
  for j=1:nChannels
    vActive(i,j)=~(isempty(v{i}(:,:,j)) || iszero(v{i}(:,:,j)));
    wActive(i,j)=~(isempty(w{i}(:,:,j)) || iszero(w{i}(:,:,j)));
  end
end

%iterate DIM equations to determine neural responses
fprintf(1,'dim_conv: ');
for t=1:iterations
  fprintf(1,'.%i.',t);
  %update error-detecting neuron responses
  R=reconstruct_from_predictions(Y,v,vActive,a,b,nInputChannels);
  %calc error between reconstruction and actual input
  E=X./max(epsilon2,R);
  %skip error channels that are empty or all-zero: they don't add anything
  eActive=false(1,nChannels);
  for j=1:nChannels
    eActive(j)=~(isempty(E(:,:,j)) || iszero(E(:,:,j)));
  end
  %update prediction neuron responses
  for i=1:nMasks
    %'drive' rather than 'input' to avoid shadowing the MATLAB builtin
    drive=zeros(a,b,'single');
    for j=1:nChannels
      %sum inputs to prediction neurons from each channel
      if wActive(i,j) && eActive(j)
        drive=drive+conv2(E(:,:,j),w{i}(:,:,j),'same');
      end
    end
    %modulate prediction neuron response by current input; epsilon1 lets
    %neurons that are currently silent become active
    Y(:,:,i)=max(epsilon1,Y(:,:,i)).*drive;
  end
  Y(Y<0)=0;
end
disp(' ')
if iterations<1 && nargout>1
  %no iteration ran, so E and R were never computed: derive them from the
  %initial Y so that the requested outputs are defined
  R=reconstruct_from_predictions(Y,v,vActive,a,b,nInputChannels);
  E=X./max(epsilon2,R);
end
if nargin>=6 && ~isempty(trueRange)
  if nargout>1
    R=R(trueRange{1},trueRange{2},:);
    E=E(trueRange{1},trueRange{2},:);
  end
  Y=Y(trueRange{1},trueRange{2},:);
end

function R=reconstruct_from_predictions(Y,v,vActive,a,b,nInputChannels)
%RECONSTRUCT_FROM_PREDICTIONS  Predictive reconstruction of the input from Y.
%   R(:,:,j) is the sum over neuron types i of conv2(Y(:,:,i),v{i}(:,:,j),'same'),
%   skipping mask/channel pairs flagged inactive in vActive (nMasks x nChannels)
%   and prediction planes that are empty or all-zero. Negative values are
%   clipped to zero. R is single precision with nInputChannels planes.
[nMasks,nChannels]=size(vActive);
R=zeros(a,b,nInputChannels,'single');
%Y does not change while R is being built, so test each plane only once
yActive=false(1,nMasks);
for i=1:nMasks
  yActive(i)=~(isempty(Y(:,:,i)) || iszero(Y(:,:,i)));
end
for j=1:nChannels
  for i=1:nMasks
    %sum reconstruction over each RF type
    if vActive(i,j) && yActive(i)
      R(:,:,j)=R(:,:,j)+conv2(Y(:,:,i),v{i}(:,:,j),'same');
    end
  end
end
R(R<0)=0;
function w=pad_to_make_odd(w)
%PAD_TO_MAKE_ODD  Zero-pad a mask so its first two dimensions have odd length.
%   An even number of rows gets one zero row appended at the bottom; an even
%   number of columns gets one zero column appended on the right. Dimensions
%   that are already odd are left untouched, as are all planes beyond the 2nd
%   dimension (each appended row/column spans every plane).
[nRows,nCols,nPlanes]=size(w);
if mod(nRows,2)==0
  w=cat(1,w,zeros(1,nCols,nPlanes));
  nRows=nRows+1; %keep the row count in step with the padded array
end
if mod(nCols,2)==0
  w=cat(2,w,zeros(nRows,1,nPlanes));
end
function X=trimarray_symmetrically(X)
%TRIMARRAY_SYMMETRICALLY  Remove all-zero border rows/columns in symmetric pairs.
%   X = TRIMARRAY_SYMMETRICALLY(X) removes rows/columns of zeros from the edges
%   of X symmetrically, i.e. the same number of rows is removed from the top as
%   from the bottom, and the same number of columns from the left as from the
%   right, so the centre of the array does not move. A row (column) counts as
%   zero only if it is zero in every plane along the 3rd and later dimensions.
%   An array containing no non-zero element is returned unchanged.
%
%   Previously the row/column profiles were computed with sum(sum(abs(X),2)) and
%   sum(sum(abs(X),1)); the outer sum re-collapsed the profile to a single total,
%   so nothing was ever trimmed. The profiles are now reduced over the other
%   dimensions explicitly.
absX=abs(X(:,:,:)); %fold any trailing dimensions into the 3rd
rowMass=sum(sum(absX,2),3); %n x 1: total magnitude in each row
colMass=sum(sum(absX,1),3); %1 x m: total magnitude in each column
firstRow=find(rowMass>0,1,'first');
if isempty(firstRow)
  %all zero: there are no non-zero rows/columns to centre the crop on
  return
end
lastRow=find(rowMass>0,1,'last');
firstCol=find(colMass>0,1,'first');
lastCol=find(colMass>0,1,'last');
%crop only as much as the narrower zero margin allows on each axis
cropA=min(firstRow-1,size(X,1)-lastRow);
cropB=min(firstCol-1,size(X,2)-lastCol);
X=X(cropA+1:size(X,1)-cropA,cropB+1:size(X,2)-cropB,:);