function  [Sest,Pest,err,itr]=obdBCS(A,B,k,M)
% [Sest,Pest,err,itr]=obdBCS(A,B,k,M)
% Runs the OBD-BCS algorithm: alternates between sparse coding of S (OMP)
% and a block-by-block orthogonal update of the block-diagonal basis P, so
% that B is approximated by A*P*S.
%Inputs:
%    A - the measurements matrix (must be a union of orthogonal bases).
%    B - the measurements.
%    k - the maximal number of nonzeros in the columns of S.
%    M - the ratio between the number of blocks in P and the number of
%           blocks in A. Optional; by default M=2.
%Outputs:
%    Sest - the estimated sparse matrix.
%    Pest - the estimated basis.
%    err - the error ||B-APS||_F (Frobenius norm, NOT squared).
%    itr - number of iterations.
%
%Note that M must be an integer larger than 1, that the number of rows in
%B must be an integer multiple of M, and that the number of columns in A
%must be an integer multiple of size(B,1)/M (the size of each block of P).
%
% Before running this function the OMP package of Ron Rubinstein
% (Computer Science Department Technion) must be installed. It can be
% downloaded from: http://www.cs.technion.ac.il/~ronrubin/software.html

addpath('ompbox') % change to the correct path if needed

%% Default arguments
if nargin < 4 || isempty(M)
    M = 2;     % documented default
end

%% Parameters
n=size(B,1);   % number of measurements
N=size(B,2);   % number of signals
m=size(A,2);   % length of the signals
MaxItr=50;     % maximal number of iterations
tol=0.001;     % tolerance on the change between consecutive iterations

if (mod(M,1)) || M<2
    error('M must be an integer larger then 1')
end
if mod(n,M)
    error('The number of rows in B must be an integer multiple of M.')
end
blk=n/M;       % size of each diagonal block of P
if mod(m,blk)
    % Without this check the block loop below would silently skip the
    % trailing columns of P.
    error('The number of columns in A must be an integer multiple of size(B,1)/M.')
end
nBlocks=m/blk; % number of diagonal blocks in P (equals M*m/n)

%% Initialization
itr=0;              % iteration index
Sest=zeros(m,N);
Pest=eye(m);

%% Algorithm
while (itr<MaxItr)
    Sest_old=Sest;
    Pest_old=Pest;

    % Sparse coding with the current basis
    Sest=sparseCode(A*Pest,B,k);

    % Basis update: update one diagonal block of P at a time, holding the
    % others fixed. For block i, with Bi the residual left after removing
    % the other blocks, the orthogonal P_i maximizing
    % trace(P_i*S_i*Bi'*A_i) is V*U', where U*Sigma*V' = svd(S_i*Bi'*A_i)
    % (orthogonal Procrustes problem).
    for i=1:nBlocks
        idx=((i-1)*blk+1):i*blk;
        Ptemp=Pest;
        Ptemp(idx,idx)=zeros(blk);     % drop block i's contribution
        Bi=B-A*Ptemp*Sest;              % part of B left for block i to explain
        [u,~,v]=svd(Sest(idx,:)*Bi'*A(:,idx));
        Pest(idx,idx)=v*u';
    end

    itr=itr+1;
    if norm(Sest_old-Sest,'fro')+norm(Pest_old-Pest,'fro')<tol     %change from last iteration
        break
    end
end

%% Final calculation of Sest
Sest=sparseCode(A*Pest,B,k);

err=norm(B-A*Pest*Sest,'fro');  % Frobenius-norm error (not squared)


function S=sparseCode(D,B,k)
% S=sparseCode(D,B,k)
% Sparse-codes the columns of B over the dictionary D with OMP-Box's omp,
% allowing at most k nonzeros per column. The dictionary is column-
% normalized before calling omp, and the coefficients are rescaled
% afterwards so that B is approximated by D*S for the ORIGINAL
% (unnormalized) D.
Dnorm=1./sqrt(sum(abs(D).^2,1));   % reciprocal column norms (row vector)
Dn=D*diag(Dnorm);                  % dictionary normalization (for OMP)
S=diag(Dnorm)*omp(Dn,B,Dn'*Dn,k);


