function [X_MAP,S_MAP,k_MAP] = BM_MAP_OMP(Y,var_e,model_params)
%BM_MAP_OMP computes an OMP-like approximation of the BM-based MAP estimator
%of the representation coefficients.
% BM - Boltzmann Machine
% BM-based model - BM prior for the sparsity pattern and Gaussian nonzero coefficients.
% This is a fast version of the algorithm: the inverse and determinant of
% Q_s = A_s'*A_s + var_e*diag(1./var_x(s)) are updated recursively (block
% matrix inversion via the Schur complement) each time an atom is added,
% instead of being recomputed from scratch.
% Atoms are added greedily, one per iteration, picking the candidate with the
% highest score. The greedy search for a signal stops when
%   - the residual norm drops to 1e-3 or below, or
%   - the support reaches min(ceil(n/2), m) atoms, or
%   - adding the best candidate would decrease the MAP objective; in that case
%     the candidate is discarded and the previous support is kept.
% =====================================================================================
% Input:
% Y - an n-by-N matrix consisting of N noisy signals.
% var_e - the variance of the gaussian distribution of the additive noise.
% model_params - parameters for the stochastic model. Required fields:
% model_params.dictionary - a matrix of size n-by-m consisting of the dictionary atoms.
% model_params.variances - a vector of size m-by-1 consisting of the variances of the
% gaussian distributions of the nonzero representation coefficients.
% model_params.W, model_params.b - the Boltzmann parameters: an interaction matrix
% of size m-by-m and a bias vector of size m-by-1.
% =====================================================================================
% Output:
% X_MAP - an m-by-N matrix consisting of the recovered representation vectors.
% S_MAP - an m-by-N matrix consisting of the sparsity patterns: +1 for atoms in
% the recovered support, -1 elsewhere.
% k_MAP - a vector of size 1-by-N consisting of the lengths of the recovered supports.
% =====================================================================================
% Tomer Faktor
% Department of Electrical Engineering
% Technion, Haifa 32000 Israel
% tomerfa@tx.technion.ac.il
%
% August 2011
% =====================================================================================
N=size(Y,2);
A=model_params.dictionary;
[n,m]=size(A);
indices=1:m;
var_x=model_params.variances;
W=model_params.W;
b=model_params.b;
err=1e-3;   % stopping threshold on the residual norm
% Support-size limit. Capping it at m keeps the greedy loop from running out
% of candidate atoms when the dictionary has fewer than n/2 atoms. Without the
% cap, max() over an empty score vector yields an empty index and indexing fails.
maxNumCoef=min(n/2,m);
X_MAP=zeros(m,N);
S_MAP=-1*ones(m,N);
k_MAP=zeros(1,N);
hh = waitbar(0,'Approximating MAP via OMP-like approach');
% Make sure the waitbar is closed even if an error interrupts the computation.
cleanupObj = onCleanup(@() close_waitbar_if_open(hh)); %#ok<NASGU>
% Coefficients of the terms of the MAP objective that are linear in S
% (loop-invariant, so computed once).
lin_coef=b'-0.25*log(var_x'/var_e);
for l=1:N
    if ~rem(l,100)
        waitbar(l/N,hh)
    end
    % Initialization and setting the parameters for the stopping rule
    y=Y(:,l);
    all_inds=ones(m,1);     % 1 marks atoms not yet in the support
    s_recov=[];             % accepted support (row of atom indices)
    x_recov=zeros(m,1);
    norm_r=norm(y);
    s_recov_prev=[];
    x_recov_prev=zeros(m,1);
    % MAP objective of the empty support (S = -1 everywhere).
    S0=-ones(m,1);
    MAP_arg_s_prev=0.5*S0'*W*S0+lin_coef*S0;
    % Main loop
    i=0;
    while norm_r>err && i < maxNumCoef
        i=i+1;
        rem_inds=find(all_inds>0);   % candidate atoms (column vector)
        h=m+1-i;                     % number of candidates
        % Each row of s_curr is the current support extended by one candidate.
        s_curr=zeros(h,i);
        if i>1
            s_curr(:,1:i-1)=repmat(s_recov,[h,1]);
            A_prev=A(:,s_recov);     % atoms of the accepted support
        end
        s_curr(:,i)=rem_inds;
        val=zeros(h,1);
        x_s=zeros(i,h);
        if i==1
            inv_Q_curr=zeros(1,m);
        else
            inv_Q_curr=zeros(i,i,h);
        end
        det_Q_curr=zeros(1,h);
        for j=1:h
            p=rem_inds(j);
            As=A(:,s_curr(j,:));
            if i==1
                % Q is the scalar a_p'*a_p + var_e/var_x(p).
                det_Q_curr(j)=As'*As+var_e/var_x(p);
                inv_Q_curr(j)=1/det_Q_curr(j);
                x_s(:,j)=inv_Q_curr(j)*As'*y;
            else
                % Q_new = [Q_prev, B; B', F]. Invert it blockwise through the
                % Schur complement F - B'*inv(Q_prev)*B (whose inverse is V1).
                % det(Q_new) = det(Q_prev)*(Schur complement) = det(Q_prev)/V1.
                B=A_prev'*A(:,p);
                F=A(:,p)'*A(:,p)+var_e/var_x(p);
                V1=1/(F-B'*inv_Q_prev*B);
                V2=inv_Q_prev+inv_Q_prev*B*V1*B'*inv_Q_prev;
                V3=-inv_Q_prev*B*V1;
                inv_Q_curr(:,:,j)=[V2,V3;V3',V1];
                x_s(:,j)=inv_Q_curr(:,:,j)*As'*y;
                det_Q_curr(j)=det_Q_prev/V1;
            end
            % Score for adding atom p: data term, log-det term, variance of p,
            % and the W/b terms evaluated for S(p) = +1 with all other entries
            % of S_vec fixed (+1 on the support, -1 elsewhere).
            S_vec=-1*ones(m,1);
            S_vec(s_curr(j,:))=1;
            other_inds=(indices~=p);
            val(j)=y'*As*x_s(:,j)/(2*var_e)-0.5*log(abs(det_Q_curr(j)))-0.5*log(var_x(p))+...
                2*W(p,other_inds)*S_vec(other_inds(:))+2*b(p);
        end
        % Tentatively add the best-scoring candidate.
        [~,ind_max]=max(val);
        add2s=rem_inds(ind_max);
        s_recov=[s_recov,add2s]; %#ok<AGROW>
        all_inds(add2s)=0;
        x_recov(s_recov)=x_s(:,ind_max);
        if i==1
            inv_Qs=inv_Q_curr(ind_max);
        else
            inv_Qs=inv_Q_curr(:,:,ind_max);
        end
        As=A(:,s_recov);
        norm_r=norm(y-As*x_s(:,ind_max));
        S_vec=-1*ones(m,1);
        S_vec(s_recov)=1;
        det_Qs=det_Q_curr(ind_max);
        % MAP objective of the enlarged support.
        MAP_arg_s_curr=y'*As*x_recov(s_recov)/(2*var_e)-0.5*log(abs(det_Qs))+0.5*S_vec'*W*S_vec+...
            lin_coef*S_vec;
        if MAP_arg_s_curr<MAP_arg_s_prev
            % The objective decreased: reject this atom and stop.
            x_recov=x_recov_prev;
            s_recov=s_recov_prev;
            break;
        else
            % Accept: keep the quantities needed for the next recursive update.
            inv_Q_prev=inv_Qs;
            det_Q_prev=det_Qs;
            MAP_arg_s_prev=MAP_arg_s_curr;
            x_recov_prev=x_recov;
            s_recov_prev=s_recov;
        end
    end
    X_MAP(:,l)=x_recov(:);
    S_MAP(s_recov,l)=1;
    k_MAP(l)=numel(s_recov);
end
close(hh)

function close_waitbar_if_open(h)
% Close the waitbar figure h unless it has already been closed.
if ishghandle(h)
    close(h);
end