function [X_MAP,S_MAP,k_MAP] = BM_unitary_MAP_message_passing(Y,var_e,model_params)
%BM_unitary_MAP_message_passing computes the exact BM-based MAP estimator of the 
%representation coefficients for a unitary dictionary and a banded interaction matrix. 
% BM - Boltzmann Machine
% BM-based model - BM prior for the sparsity pattern and Gaussian nonzero coefficients.
% This algorithm is based on message passing techniques.
% =====================================================================================
% Input:
% Y - an n-by-N matrix consisting of N noisy signals. 
% var_e - the variance of the gaussian distribution of the additive noise.
% model_params - parameters for the stochastic model. Required fields:
% model_params.dictionary - a matrix of size n-by-n consisting of the dictionary atoms.
% The dictionary should be square and unitary.
% model_params.variances - a vector of size n-by-1 consisting of the variances of the 
% gaussian distributions of the nonzero representation coefficients.
% model_params.W, model_params.b - the Boltzmann parameters: an interaction matrix
% of size n-by-n and a bias vector of size n-by-1. The interaction matrix should be
% banded (bandwidth at most 2*log2(n)). 
% =====================================================================================
% Output:
% X_MAP - an n-by-N matrix consisting of the restored representation vectors. 
% S_MAP - an n-by-N matrix consisting of the sparsity patterns (entries in {-1,+1}).
% k_MAP - a 1-by-N row vector consisting of the lengths of the recovered supports.
% If the input is invalid, a message is displayed and all three outputs are empty.
% =====================================================================================
% Tomer Faktor
% Department of Electrical Engineering
% Technion, Haifa 32000 Israel
% tomerfa@tx.technion.ac.il
%
% August 2011
% =====================================================================================
[n,N]=size(Y);
A=model_params.dictionary;
X_MAP=[];
S_MAP=[];
k_MAP=[];
% Check input parameters.
% The dictionary must be exactly n-by-n. Checking only the number of columns would
% accept a non-square A, which would then fail at A'*y inside the loop below.
if size(A,1)~=n || size(A,2)~=n
    display('Invalid input - the dictionary should be square')
    return;
end
G=A'*A;
if any(any(abs(G-eye(n))>1e-10))
    display('Invalid input - the dictionary should be unitary')
    return;
end
var_x=model_params.variances;
W=model_params.W;
b=model_params.b;
% Bandwidth L of W: the largest column-row offset of an entry on or above the main
% diagonal whose magnitude exceeds 1e-10.
[r,c]=find(abs(triu(W))>1e-10);
if ~isempty(r)
    L=max(c-r); % number of non-zero diagonals in the upper triangle part of W
else
    L=0;
end
if L>2*log2(n) % for L=2*log2(n) the computational complexity is O(n^3)
  display('Invalid input - the interaction matrix should be banded')
  return;      
end
% Prepare for message passing.
% k+1 is the index of the middle clique; the forward/backward backtracking below
% uses clique indices 1..n-L.
k=floor((n-L-1)/2);
% S_clique enumerates all 2^(L+1) sign patterns of a clique: column j holds the
% pattern whose row i is bit (i-1) of (j-1), mapped from {0,1} to {-1,+1}.
S_clique=-ones(L+1,2^(L+1));
v=0:2^(L+1)-1;
for i=L+1:-1:1
    S_clique(i,:)=2*(floor(v/2^(i-1)))-1;
    v=rem(v,2^(i-1));
end
% Loop-invariant terms of the per-atom log-likelihood ratio q.
gain=var_x./(var_e*(var_e+var_x));
log_term=log(1+var_x/var_e);
% Perform message passing
S_MAP=-1*ones(n,N);
hh=waitbar(0,'Computing exact MAP via message passing');
% Close the waitbar even if an error interrupts the loop below (or do nothing if
% the user has already closed it).
waitbar_cleanup=onCleanup(@() delete(hh(ishghandle(hh))));
for l=1:N
    if ~rem(l,100) && ishghandle(hh)
        waitbar(l/N,hh)
    end
    S1_MAP=-ones(n,1);
    y=Y(:,l);
    q=b+0.25*(gain.*(A'*y).^2-log_term);
    [clique_potentials,clique_nodes] = compute_clique_props(W,q,S_clique);
    if L>0
        [m_forward,m_backward] = compute_messages(clique_potentials,S_clique);
        % Jointly maximize over the middle clique, combining its potential with the
        % incoming forward message (first L nodes) and backward message (last L nodes).
        [~,ind_MAP]=max(repmat(m_forward(k+1,:),1,2)+...
            reshape(repmat(m_backward(k+1,:),2,1),1,2^(L+1))+clique_potentials{k+1});
        S1_MAP(clique_nodes{k+1})=S_clique(:,ind_MAP);
        % Backtrack towards the first atom: decide atom i given atoms i+1..i+L
        % (ties resolve to +1).
        for i=k:-1:1
            v=(0.5*(S1_MAP(i+1:i+L)+1)).*(2.^(1:L)');
            ind1=sum(v);
            ind2=sum(v(1:end-1));
            S1_MAP(i)=2*(clique_potentials{i}(ind1+2)+m_forward(i,ind2+2)>=...
                clique_potentials{i}(ind1+1)+m_forward(i,ind2+1))-1;
        end
        % Backtrack towards the last atom: decide atom i+L given atoms i..i+L-1
        % (ties resolve to +1).
        for i=k+2:n-L
            v1=(0.5*(S1_MAP(i:i+L-1)+1)).*(2.^(0:L-1)');
            v2=(0.5*(S1_MAP(i+1:i+L-1)+1)).*(2.^(0:L-2)');
            ind1=sum(v1);
            ind2=sum(v2);
            S1_MAP(i+L)=2*(clique_potentials{i}(ind1+1+2^L)+m_backward(i,ind2+1+2^(L-1))>=...
                clique_potentials{i}(ind1+1)+m_backward(i,ind2+1))-1;
        end
    else
        % No interactions: each atom is decided independently from its own potential.
        for i=1:n
            [~,ind_MAP]=max(clique_potentials{i});
            S1_MAP(i)=S_clique(ind_MAP);
        end
    end
    S_MAP(:,l)=S1_MAP;
end
% Destroying the cleanup object closes the waitbar now, before the oracle step.
clear waitbar_cleanup
% Sum explicitly along dimension 1 so that n=1 still yields one count per signal.
k_MAP=sum(S_MAP==1,1);
% Compute estimates for the representations using the oracle formula
X_MAP=unitary_oracle_formula(Y,var_e,model_params,S_MAP);