function [x,r,normR,residHist, errHist] = BOMP( A, b, d, k, errFcn, opts )
% BOMP  Block Orthogonal Matching Pursuit.
%   Uses the Block Orthogonal Matching Pursuit algorithm (BOMP)
%   to estimate the solution to the equation
%       b = A*x     (or b = A*x + noise )
%   where there is prior information that x is block-sparse: x is split into
%   N/d consecutive blocks of length d, and only a few blocks are nonzero.
%
%   "A" may be a matrix, or it may be a cell array {Af,At}
%   where Af and At are function handles that compute the forward and
%   (conjugate) transpose multiplies, respectively. In the cell-array form,
%   columns of A are obtained by applying Af to unit vectors.
%
% [x,r,normR,residHist,errHist] = BOMP( A, b, d, k, errFcn, opts )
%   is the full version.
% Outputs:
%   'x'         is the block k-sparse estimate of the unknown signal,
%               expressed with respect to the original columns of A
%   'r'         is the residual b - A*x
%   'normR'     = norm(r)
%   'residHist' is a vector with normR from every iteration performed
%   'errHist'   is a vector with the output of errFcn(x) from every iteration
%               performed (all zeros if errFcn is empty)
%
% Inputs:
%   'A'     is the measurement matrix (or {Af,At}, see above)
%   'b'     is the (column) vector of observations
%   'd'     is the block size (a positive integer dividing size(A,2));
%           d = 1 gives plain OMP
%   'k'     is the estimate of the block sparsity, i.e. the number of
%              nonzero blocks (you may wish to purposefully over- or
%              under-estimate the sparsity, depending on noise)
%              N.B. d*k <= size(A,1) is necessary, otherwise we cannot
%                   solve the internal least-squares problem uniquely.
%                   The selected columns are assumed linearly independent.
%
%   'k' (alternative usage):
%           instead of specifying the expected sparsity, you can specify
%           the expected residual. Set 'k' to the residual. The code
%           will automatically detect this if 'k' is not an integer;
%           if the residual happens to be an integer, so that confusion could
%           arise, then specify it within a cell, like {k}.
%           In this mode at most floor(size(A,1)/d) blocks are selected.
%
%   'errFcn'    (optional; set to [] to ignore) is a function handle
%              which will be used to calculate the error; it is called as
%              errFcn(x) once per iteration and its output should be a scalar
%
%   'opts'  is a structure with more options, including:
%       .printEvery = integer controlling how often progress is printed
%                     (every printEvery iterations and on the last one;
%                     0 or Inf disables printing). Default 50.
%       .maxiter, .slowMode = accepted for compatibility with OMP but
%                     ignored in this version.
%
%       Note that these field names are case sensitive!
%
% If you need a faster implementation, try the very good C++ implementation
% (with mex interface to Matlab) in the "SPAMS" toolbox, available at:
%   http://www.di.ens.fr/willow/SPAMS/
% The code in SPAMS is precompiled for most platforms, so it is easy to install.
% SPAMS uses Cholesky decompositions and uses a slightly different
%   updating rule to select the next atom.
%
% Stephen Becker, Aug 1 2011.  srbecker@alumni.caltech.edu
% Updated Dec 12 2012, fixing bug for complex data, thanks to Noam Wagner.
%   See also CoSaMP, test_OMP_and_CoSaMP
%
% Hadas Frostig, Apr 27 2017, converted code from performing OMP to BOMP.
% Can still be used for OMP with d=1. Slow mode not supported in this
% version.

if nargin < 6, opts = []; end
if ~isempty(opts) && ~isstruct(opts)
    error('"opts" must be a structure');
end

    function out = setOpts( field, default )
        % Return opts.(field), filling in 'default' when the field is absent.
        if ~isfield( opts, field )
            opts.(field)    = default;
        end
        out = opts.(field);
    end

printEvery  = setOpts( 'printEvery', 50 );
printOn     = printEvery > 0 && printEvery < Inf;

if nargin < 5
    errFcn = [];
elseif ~isempty(errFcn) && ~isa(errFcn,'function_handle')
    error('errFcn input must be a function handle (or leave the input empty)');
end

if ~isscalar(d) || d < 1 || d ~= round(d)
    error('Block size "d" must be a positive integer');
end

% What stopping criteria to use? either a fixed # of blocks,
%   or a desired size of residual:
target_resid    = -Inf;
useResid        = false;
if iscell(k)
    target_resid = k{1};
    useResid     = true;
elseif k ~= round(k)
    target_resid = k;
    useResid     = true;
end
% (the residual is always guaranteed to decrease)
if target_resid == 0
    if printOn
        disp('Warning: target_resid set to 0. This is difficult numerically: changing to 1e-12 instead');
    end
    target_resid    = 1e-12;
end

% -- Forward / (conjugate) transpose operators --
isOperator  = iscell(A);
if isOperator
    if numel(A) ~= 2 || ~isa(A{1},'function_handle') || ~isa(A{2},'function_handle')
        error('If "A" is a cell array it must be {Af,At} with Af and At function handles');
    end
    Af  = A{1};
    At  = A{2};
else
    At  = @(z) A'*z;    % conjugate-transpose multiply; returns a column of length N
end

% -- Initialize --
% start at x = 0, so r = b - A*x = b
r           = b;
normR       = norm(r);
Ar          = At(r);            % correlations of every atom with the residual
N           = size(Ar,1);       % number of atoms
M           = size(r,1);        % size of atoms
if rem(N,d) ~= 0
    error('Non integral amount of blocks');
end
n           = N/d;              % number of blocks
if useResid
    % No sparsity given: allow as many blocks as keep the
    % least-squares problem uniquely solvable.
    k   = floor(M/d);
end
if d*k > M
    error('d*k cannot be larger than the dimension of the atoms');
end

x           = zeros(N,1);
indx_set    = zeros(d*k,1);     % indices of all chosen atoms, block by block
A_T         = zeros(M,d*k);     % orthonormalized chosen atoms (MGS "Q" factor)
R           = zeros(d*k,d*k);   % MGS coefficients: A(:,indx_set) = A_T*R, R upper triangular
residHist   = zeros(k,1);
errHist     = zeros(k,1);
x_T         = [];               % coefficients of b in the orthonormal basis A_T
nBlocks     = 0;                % number of blocks actually added to the support

for kk = 1:k
    doPrint = printOn && ( mod(kk,printEvery) == 0 || kk == k );
    if doPrint
        disp(['Current iteration = ',num2str(kk)]);
    end

    % -- Step 1: pick the block whose correlation with the residual has the
    %    largest l2 norm (abs() is required so complex data is handled correctly)
    blockEnergy     = sum( abs(reshape(Ar,d,n)).^2, 1 );
    [~,b_ind_new]   = max(blockEnergy);
    cols            = d*(b_ind_new-1)+1 : d*b_ind_new;   % atoms of the chosen block
    if doPrint
        disp(['Chosen block index = ', num2str(b_ind_new)]);
    end
    % The residual is orthogonal to every chosen atom, so an already chosen
    % block can only win again when (numerically) no correlation is left:
    % no further progress is possible, so stop instead of corrupting the basis.
    if ismember( cols(1), indx_set(1:d*(kk-1)) )
        warning('BOMP:repeatedBlock', ...
            'Block %d was selected twice; stopping after %d blocks', b_ind_new, kk-1);
        break;
    end
    indx_set(d*(kk-1)+1:d*kk)   = cols;

    % Fetch the chosen columns of A
    if isOperator
        block_new   = zeros(M,d);
        unitVec     = zeros(N,1);
        for l = 1:d
            unitVec(cols(l))    = 1;
            block_new(:,l)      = Af(unitVec);
            unitVec(cols(l))    = 0;
        end
    else
        block_new   = A(:,cols);
    end

    % -- Step 2: orthogonalize the new atoms against all previous ones (MGS),
    %    keeping the coefficients in R so that x can later be expressed
    %    with respect to the original (non-orthogonalized) columns of A.
    for l = 1:d
        p   = d*(kk-1) + l;         % column of A_T / R being filled
        v   = block_new(:,l);
        for j = 1:p-1
            R(j,p)  = A_T(:,j)'*v;
            v       = v - R(j,p)*A_T(:,j);
        end
        R(p,p)      = norm(v);
        A_T(:,p)    = v/R(p,p);
    end
    nBlocks = kk;
    m       = d*kk;

    % -- Step 3: least squares min||b - A_T*x_T|| is trivial since A_T has
    %    orthonormal columns; then update the residual.
    x_T     = A_T(:,1:m)'*b;
    r       = b - A_T(:,1:m)*x_T;
    normR   = norm(r);
    residHist(kk)   = normR;

    if ~isempty(errFcn)
        % support only grows, so this overwrites every nonzero entry of x
        x( indx_set(1:m) )  = R(1:m,1:m)\x_T;
        errHist(kk)         = errFcn(x);
    end

    if normR < target_resid
        break;
    end

    if kk < k
        Ar  = At(r); % prepare for next round
    end
end

% Express the solution in terms of the original columns of A:
% A(:,S)*xS = A_T*R*xS = A_T*x_T  =>  xS = R\x_T
m = d*nBlocks;
if m > 0
    x( indx_set(1:m) )  = R(1:m,1:m)\x_T;
end
% Keep only the iterations that were actually performed
residHist   = residHist(1:nBlocks);
errHist     = errHist(1:nBlocks);

end % end of main function
