You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

141 lines
4.5 KiB
Matlab

function [Y,R,E,Ytrace,Rtrace,Etrace]=dim_activation_conv(w,X,iterations,Y)
% Iteratively update prediction-node (Y), reconstruction (R) and error-node
% (E) activations for a convolutional processing stage. On every iteration:
%   R{j} = sum_i conv(Y{i}, wFB{i,j})                   (reconstruction)
%   E{j} = X{j} ./ (epsilon2 + R{j})                    (divisive error)
%   Y{i} = max(0, (epsilon1 + Y{i}) .* sum_j conv(E{j}, wFF{i,j}))
% where wFF/wFB are normalised (FF: flipped) copies of w. A progress trace
% ('dim_conv(k): .1..2. ...') is printed to stdout.
%
% Inputs:
% w = a cell array of size {N,M}, where N is the number of distinct neuron types
%   and M is the number of input channels. Each element {i,j} of the cell
%   array is a 2-dimensional matrix (a convolution mask) specifying the
%   synaptic weights for neuron type i's RF in input channel j.
% X = a cell array of size {M}. Each element {j} of the cell array is a
%   2-dimensional matrix specifying the bottom-up input to channel j of
%   the current processing stage (external inputs targeting the error
%   nodes). If inputs arrive from more than one source, then each source
%   provides a different cell in the array. Input values that change over
%   time can be presented using a 3-d matrix, where the 3rd dimension is time;
%   each channel uses its own number of time steps, and once iteration t
%   exceeds that number the channel's last frame is presented. All channels
%   are expected to share the spatial size of X{1}.
% iterations = number of update iterations (default 50; an empty value also
%   selects the default).
% Y = optional cell array of size {N} of initial prediction node activations
%   (same layout as the output Y). If omitted or empty, all prediction nodes
%   start at zero.
%
% Outputs:
% Y = a cell array of size {N}. Each element {i} of the cell array is a
%   2-dimensional matrix specifying prediction node activations for type i
%   neurons.
% R = a cell array of size {M}. Each element {j} of the cell array is a
%   2-dimensional matrix specifying the reconstruction of the input channel j.
%   These values are also the top-down feedback that should modulate
%   activity in preceding processing stages of a hierarchy.
% E = a cell array of size {M}. Each element {j} of the cell array is a
%   2-dimensional matrix specifying the error in the reconstruction of the
%   input channel j.
% Ytrace = cell array {N}; Ytrace{i}(t) is the type-i prediction node
%   activation at pixel (ceil(a/2),ceil(b/2)) after iteration t.
% Rtrace = cell array {M}; Rtrace{j}(:,:,t) is R{j} (rescaled as returned)
%   at iteration t.
% Etrace = cell array {M}; Etrace{j}(:,:,t) is E{j} (rescaled as returned)
%   at iteration t.

% Use explicit per-dimension size() calls: [a,b]=size(X{1}) would fold the
% time dimension into b for 3-d inputs.
a=size(X{1},1);
b=size(X{1},2);
[nMasks,nChannels]=size(w);
[c,d]=size(w{1,1});
if nargin<3 || isempty(iterations), iterations=50; end
initY = nargin<4 || isempty(Y);

%set parameters
epsilon=1e-9;
psi=5000;
epsilon1=0.0001; %>0.001 this becomes significant compared to y and hence
%produces sustained responses and more general suppression
epsilon2=500*epsilon1*psi;%this determines scaling of initial transient response
%(i.e. response to linear filters).

%try to speed things up: pick the convolution routine once, as a handle
if exist('convnfft')==2 && (max(c,d)>=50-max(a,b) || (min(c,d)>10 && min(a,b)>10))
  conv_fft=1;%use fft version of convolution for large images and/or masks
  convSame=@(img,mask) convnfft(img,mask,'same');
else
  conv_fft=0;%use standard conv2 function for smaller images and/or masks
  convSame=@(img,mask) conv2(img,mask,'same');
end

%also convert data to single precision
%for convnfft single is always faster than double
%for conv2 single is faster than double if image is larger than mask
for j=1:nChannels
  X{j}=single(X{j});
end

%preallocate working and output containers
wFF=cell(nMasks,nChannels);
wFB=cell(nMasks,nChannels);
A=cell(1,nMasks);
R=cell(1,nChannels);
E=cell(1,nChannels);
if initY
  Y=cell(1,nMasks);
end
if nargout>3
  Ytrace=cell(1,nMasks);
  for i=1:nMasks
    Ytrace{i}=zeros(1,iterations,'single');
  end
end
if nargout>4
  Rtrace=cell(1,nChannels);
  for j=1:nChannels
    Rtrace{j}=zeros(a,b,iterations,'single');
  end
end
if nargout>5
  Etrace=cell(1,nChannels);
  for j=1:nChannels
    Etrace{j}=zeros(a,b,iterations,'single');
  end
end

%normalize weights and initialise outputs
for i=1:nMasks
  %initialise prediction neuron outputs to zero
  if initY
    Y{i}=zeros(a,b,'single');
  end
  %lateral/feedback modulation term; currently always zero
  A{i}=zeros(a,b,'single');
  %calculate normalisation values by taking into account all weights
  %contributing to each RF type
  sumVal=0;
  maxVal=0;
  for j=1:nChannels
    w{i,j}=single(w{i,j});
    sumVal=sumVal + sum(sum(w{i,j}));
    maxVal=max(maxVal,max(max(w{i,j})));
  end
  %dividing by psi here scales both weight sets up by psi; this is undone
  %on R and E just before returning
  sumVal=sumVal./psi;
  maxVal=maxVal./psi;
  %apply normalisation to calculate feedforward and feedback weight values.
  %Note: FF weights are flipped versions, so that conv2 can be used to apply the
  %filtering
  for j=1:nChannels
    wFF{i,j}=fliplr(flipud(w{i,j}))./(epsilon+sumVal);
    wFB{i,j}=w{i,j}./(epsilon+maxVal);
  end
end

fprintf(1,'dim_conv(%i): ',conv_fft);
%iterate to determine steady-state response
for t=1:iterations
  fprintf(1,'.%i.',t);
  %calculate input from lateral and feedback connections
  for i=1:nMasks
    Y{i}=Y{i}.*(1+A{i});
  end
  %update error units
  for j=1:nChannels
    R{j}=zeros(a,b,'single');%reset reconstruction of input
    for i=1:nMasks
      R{j}=R{j}+convSame(Y{i},wFB{i,j});%sum reconstruction over each RF type
    end
    %each channel holds its own last frame once t exceeds its time depth
    E{j}=X{j}(:,:,min(t,size(X{j},3)))./(epsilon2+R{j});%calc reconstruction error
    if nargout>5
      Etrace{j}(:,:,t)=E{j}.*psi;%record response over time
    end
    if nargout>4
      Rtrace{j}(:,:,t)=R{j}./psi;%record response over time
    end
  end
  %update outputs
  for i=1:nMasks
    ffInput=0;
    for j=1:nChannels
      ffInput=ffInput+convSame(E{j},wFF{i,j});%sum input to prediction node from each channel
    end
    %modulate prediction node response by input; clamp so no negative values creep in
    Y{i}=max(0,(epsilon1+Y{i}).*ffInput);
    if nargout>3
      Ytrace{i}(t)=Y{i}(ceil(a/2),ceil(b/2));%record response over time
    end
  end
end

%undo the psi scaling built into the normalised weights
for j=1:nChannels
  R{j}=R{j}./psi;
  E{j}=E{j}.*psi;
end
disp(' ');