% Script running SoS sampling and reconstruction scheme for Dirac deltas. 
% This is a clean algorithm: since we are dealing with Dirac deltas, neither discretization nor digital filtering is needed.
% Noise performance is simulated as well

% Created by Ronen Tur, 2009

close all   % close open figures
clear all   % start from an empty workspace (also clears functions, globals and MEX files)
%% User Configurable Parameters
% ---------------------------------------------------------
% ---------------------------------------------------------
K = 3;                      % number of Dirac deltas in the input signal
over_sampling_factor = 1;   % 1 means no oversampling; the simulation takes
                            % N = 2*K*over_sampling_factor + 1 samples
SNR_scale = 5:2:35;         % SNR values [dB] swept by the experiment
num_of_experiments = 40;    % noise realizations averaged for every SNR value
solve_previous_methods = 1; % 1 => also solve the previous methods (B-/E-splines and Gaussian)
% Hard thresholding parameters
bspline_threshold = 3;      % B-spline hard threshold in units of sigma_noise; used to improve spline performance
% ---------------------------------------------------------
% ---------------------------------------------------------
%% Internal Setup Params
N   = 2*K*over_sampling_factor + 1; % number of samples; odd by construction (2*K*osf + 1)
tau = 1;                            % time window containing all deltas [arbitrary units]

% The comparison methods (B-spline/Gaussian and E-spline) are switched on together
solve_bspline = 0;
solve_espline = 0;
if solve_previous_methods
    solve_bspline = 1;
    solve_espline = 1;
end

save2file_flag = 0; % 1 => save the whole workspace to a MAT-file at the end of the run
plot_flag      = 0; % 1 => print/plot additional debug output

other_threshold = 0; % hard threshold (units of sigma_noise) for the non-B-spline methods; 0 => no hard thresholding

B = N/tau;          % sinc bandwidth
T = tau/N;          % sampling period
M = floor(B*tau/2); % largest frequency index; the SoS filter uses K_set = -M:M
% Global error accumulators: one entry per SNR value, each built up as a
% running average over the experiments
[error_3P_global, error_3P_global_amplitudes, error_Cadzow_3P_global, ...
 error_sinc_global, error_Cadzow_sinc_global, ...
 error_gaussian_global, error_Cadzow_gaussian_global, ...
 error_spline_global, error_Cadzow_spline_global, ...
 error_spline_direct_global, error_Espline_direct_global, ...
 sigma_spline_average, sigma_Espline_average] = deal(zeros(size(SNR_scale)));

gauss_kernel_factor = 0.32; % internal: Gaussian kernel width as a fraction of the average delta spacing

%% B-spline initialization
if solve_bspline
    spline_order   = 2*K - 1;          % gives spline_order+1 = 2K moments (orders 0..2K-1)
    spline_support = spline_order + 1;
    % integer sample indices, symmetric about 0, spanning
    % +/-(floor(spline_support/2) + (N-1)/2)
    spline_sample_times = -(floor(spline_support/2) + (N-1)/2) : (floor(spline_support/2) + (N-1)/2);
    % column p+1 holds the CalcSplineCoeff output for moment order p
    spline_coefficients = zeros(length(spline_sample_times), spline_order + 1);
    for i_moment = 1:spline_order+1
        spline_coefficients(:,i_moment) = CalcSplineCoeff(spline_order, i_moment-1, length(spline_sample_times));
    end
end
%% E-spline initialization
% Important note: the E-spline functions are not centered!!! They start from
% t=0 and continue from there on. (The B-splines are centered!)
if solve_espline
    Espline_order = 2*K-1;
    % Select the closed-form E-spline evaluator for this order. Evaluators exist
    % only for orders 1, 3, 5, 7 and 9 (K = 1..5).
    switch K
        case 1
            espline_func_pointer = @m_espline_func_order1;
        case 2
            espline_func_pointer = @m_espline_func_order3;
        case 3
            espline_func_pointer = @m_espline_func_order5;
        case 4
            espline_func_pointer = @m_espline_func_order7_maple;
        case 5
            espline_func_pointer = @m_espline_func_order9_maple;
        otherwise
            % Without this, espline_func_pointer stays undefined and the
            % script fails later with an unrelated-looking error.
            error('No E-spline function is available for K = %d (supported: 1 to 5).', K);
    end
    
    Espline_support = Espline_order+1;
    Espline_sample_times_old = ( (-(floor(Espline_support/2)+(N-1)/2)):((floor(Espline_support/2)+(N-1)/2)) );
    Espline_sample_times = Espline_sample_times_old;
    % Load only alpha_vec from the precomputed results. A bare load() would also
    % create every other variable stored in the MAT-file, which could silently
    % overwrite script variables such as K, N or T.
    load(strcat('CalcEsplineFuncResults_Order',int2str(Espline_order)),'alpha_vec')
    [Espline_coeff_matrix,coeff_index_set] = CalcEsplineCoeffMain(Espline_order,alpha_vec,length(Espline_sample_times),espline_func_pointer);
end
%% Choose time-delays, tk, and amplitudes, xk

% Default: K unit-amplitude deltas uniformly spaced strictly inside (0, tau)
temp = linspace(0, tau, K+2);   % K+2 equispaced points, including both end points
tk   = temp(2:end-1).';         % drop the end points -> K interior delays (column vector)
xk   = ones(K,1);               % unit amplitudes

% % % % Alternative: by hand
% tk = [0.3;0.66];
% xk=[1;1];

% % Alternative: random delays
% tk = sort(tau*rand(K,1));
% xk = ones(K,1);
% % xk = (2*(random('unid',2,K,1) == 1) - 1).*(0.5 + 0.5*abs(randn(K,1))); % Negative amplitudes too

%% Filter parameters for sum of sincs filter
% SoS weights b_k = |k| + 1 for k = -M..M. The previous constant-weight
% assignment was immediately overwritten (dead store); kept here as an option.
K_set = -M:M;
b_k = abs(K_set) + 1;
% b_k = 2*ones(size(K_set));       % constant-weight alternative
% b_k = hamming(length(K_set));    % Hamming-window alternative

%% Sample Times
% Re-derive the sampling period (same value as in the setup section) so a
% variable loaded from a MAT-file in between cannot leave a stale T behind.
T = tau/N;
sample_times = (0:(N-1))*T;
[SAMPLE_TIMES,TK] = meshgrid(sample_times,tk);

%% Sum of sincs 3 period filter samples
if plot_flag
    disp('----------------------------------------------------------------------------------------------------------');
    disp('Using 3 period filter');
end
% Samples are the superposition of one SumOfSincsFilter response per delta,
% weighted by that delta's amplitude.
num_periods = 3;
c_n_3P = zeros(length(sample_times),1);
for tk_index = 1:length(tk)
    c_n_3P = c_n_3P + ...
        xk(tk_index)*SumOfSincsFilter(sample_times - tk(tk_index), tau, K_set, b_k, plot_flag, num_periods);
end

%% Periodic sinc samples
if plot_flag
    disp('----------------------------------------------------------------------------------------------------------');
    disp('Using periodic sinc filter');
end
% nT_minus_tk(n,k) = sample_times(n) - tk(k): one row per sample, one column per delta
nT_minus_tk = (SAMPLE_TIMES - TK).';
c_n_sinc = Phi(nT_minus_tk, tau, B) * xk;

%% Gaussian kernel samples
% (the Gaussian comparison runs together with the B-spline one)
if solve_bspline
    average_spacing = 0.5*tau;              % fallback for a single delta
    if K ~= 1
        average_spacing = mean(diff(tk));
    end
    % kernel width proportional to the average spacing between deltas
    sigma_gauss_kernel = gauss_kernel_factor*average_spacing;
    c_n_gaussian = gaussian_kernel(nT_minus_tk, sigma_gauss_kernel) * xk;
end
%% spline samples
if solve_bspline
    % Delays are centred on the window (tk - tau/2) and expressed in units of T
    [SPLINE_SAMPLE_TIMES,SPLINE_TK] = meshgrid(spline_sample_times, tk - tau/2);
    spline_nT_minus_tk = (SPLINE_TK/T - SPLINE_SAMPLE_TIMES).';
    c_n_spline = calcspline(spline_nT_minus_tk, spline_order) * xk;
end
%% E-spline samples
if solve_espline
    % Note the opposite sign convention from the B-spline case: index - (tk - tau/2)/T
    [ESPLINE_SAMPLE_TIMES,ESPLINE_TK] = meshgrid(coeff_index_set, tk - tau/2);
    Espline_nT_minus_tk = (ESPLINE_SAMPLE_TIMES - ESPLINE_TK/T).';
    c_n_Espline = espline_func_pointer(Espline_nT_minus_tk) * xk;
end
%% Run in a loop for different experiments
% Progress bar with a Cancel button: the button callback sets the 'canceling'
% app-data flag, which is polled at the start of every experiment.
% NOTE: after a cancel, the running averages are still scaled by
% 1/num_of_experiments, i.e. they are not renormalized to the completed runs.
h = waitbar(0,sprintf('Out of %d iterations',num_of_experiments),'Name','Iteration no. progress',...
            'CreateCancelBtn',...
            'setappdata(gcbf,''canceling'',1)');
setappdata(h,'canceling',0)
for i_experiment = 1:num_of_experiments
    if getappdata(h,'canceling')
        break
    end
    % Pass the handle so that this script's own waitbar is the one updated
    waitbar(i_experiment / num_of_experiments, h)
    if plot_flag
        fprintf('Iteration no.%d\n', i_experiment)
    end

% Store the legacy randn state at the start of the first experiment (not used
% elsewhere in this script; it stays in the workspace, which is saved when
% save2file_flag is on).
if i_experiment == 1
    randn_state_noise = randn('state');
end
% One noise realization per experiment, reused for every SNR value below, so
% only its scaling changes along the SNR axis.
sinc_and_SOSfilter_noise = randn(length(sample_times),1);
if solve_bspline
    bspline_noise = randn(length(spline_sample_times),1);
end
if solve_espline
    if solve_bspline
        % same realization as the B-spline method, for a like-for-like comparison
        Espline_noise = bspline_noise;
    else
        % bspline_noise is not drawn in this configuration; draw a vector
        % matching the E-spline samples instead
        Espline_noise = randn(numel(c_n_Espline),1);
    end
end
%% Run a loop on different SNR
for i_SNR = 1:length(SNR_scale)
digital_SNR = 10^(SNR_scale(i_SNR)/10); % linear SNR

%% Solve using 3 period filter
% Add scaled noise (unless the SNR is effectively infinite), hard-threshold,
% then estimate delays and amplitudes with the SoS-filter solver.
if digital_SNR<1e15
    sigma_3P = calc_sigma(c_n_3P, sinc_and_SOSfilter_noise, digital_SNR);
    c_n_3P_noisy = c_n_3P + sigma_3P*sinc_and_SOSfilter_noise;
    % hard thresholding with threshold other_threshold*sigma_3P
    c_n_3P_noisy = c_n_3P_noisy.*(abs(c_n_3P_noisy) > other_threshold*sigma_3P);
else
    c_n_3P_noisy = c_n_3P; % SNR >= 150 dB: treated as noiseless
end
[tk_FRI_3P, tk_FRI_Cadzow_3P, xk_FRI_3P, xk_FRI_Cadzow_3P] = ...
    Estimate_tk_xk(c_n_3P_noisy, K, tau, K_set, b_k, tk, xk, 'SoS_filter');
% running averages (over experiments) of the squared estimation errors
error_3P_global(i_SNR) = error_3P_global(i_SNR) + 1/num_of_experiments*norm(tk_FRI_3P-tk)^2;
error_3P_global_amplitudes(i_SNR) = error_3P_global_amplitudes(i_SNR) + 1/num_of_experiments*norm(xk_FRI_3P-xk)^2;
error_Cadzow_3P_global(i_SNR) = error_Cadzow_3P_global(i_SNR) + 1/num_of_experiments*norm(tk_FRI_Cadzow_3P-tk)^2;
%% Solve using periodic sinc filter
if digital_SNR<1e15
    sigma_sinc = calc_sigma(c_n_sinc, sinc_and_SOSfilter_noise, digital_SNR);
    c_n_sinc_noisy = c_n_sinc + sigma_sinc*sinc_and_SOSfilter_noise;
    c_n_sinc_before_threshold = c_n_sinc_noisy; % kept for inspection; not used elsewhere in this script
    c_n_sinc_noisy = c_n_sinc_noisy.*(abs(c_n_sinc_noisy) > other_threshold*sigma_sinc);
else
    c_n_sinc_noisy = c_n_sinc; % SNR >= 150 dB: treated as noiseless
end
% the plain periodic sinc corresponds to flat weights b_k = 1
b_k_sinc = ones(length(K_set),1);
[tk_FRI_sinc, tk_FRI_Cadzow_sinc, xk_FRI_sinc, xk_FRI_Cadzow_sinc] = ...
    Estimate_tk_xk(c_n_sinc_noisy, K, tau, K_set, b_k_sinc, tk, xk, 'sinc');
error_sinc_global(i_SNR) = error_sinc_global(i_SNR) + 1/num_of_experiments*norm(tk_FRI_sinc-tk)^2;
error_Cadzow_sinc_global(i_SNR) = error_Cadzow_sinc_global(i_SNR) + 1/num_of_experiments*norm(tk_FRI_Cadzow_sinc-tk)^2;
%% Solve using Gaussian kernel
if solve_bspline
    if digital_SNR<1e15
        sigma_gaussian = calc_sigma(c_n_gaussian,sinc_and_SOSfilter_noise,digital_SNR);
        c_n_gaussian_noisy = c_n_gaussian + sigma_gaussian*sinc_and_SOSfilter_noise;
        c_n_gaussian_before_threshold = c_n_gaussian_noisy;
        c_n_gaussian_noisy = c_n_gaussian_noisy.*(abs(c_n_gaussian_noisy)>other_threshold*sigma_gaussian);
    else
        c_n_gaussian_noisy = c_n_gaussian;
    end
    b_k_gaussian = ones(length(K_set),1);
    % Multiply by exp((nT)^2/(2*sigma^2)) to remove the Gaussian envelope. Assuming
    % gaussian_kernel(t,s) = exp(-t.^2/(2*s^2)) (TODO confirm), the result is a sum
    % of exponentials u_k^n with u_k = exp(T*tk/sigma^2), which is inverted below.
    c_n_gaussian_modified = c_n_gaussian_noisy.*exp( sample_times.^2 / (2*sigma_gauss_kernel^2) ).';
    uk_FRI_gaussian = annihilating_filter(c_n_gaussian_modified,K);
    tk_FRI_gaussian = sigma_gauss_kernel^2*log(uk_FRI_gaussian)/T;
    % The error below pairs the estimates with tk (ascending) by position, and the
    % order returned by annihilating_filter is not guaranteed. Order the estimates
    % by real part first, as the other methods do (no-op if already ordered).
    [dummy, gauss_order] = sort(real(tk_FRI_gaussian));
    tk_FRI_gaussian = tk_FRI_gaussian(gauss_order);
    
    error_gaussian_global(i_SNR) = error_gaussian_global(i_SNR) + 1/num_of_experiments*norm(tk_FRI_gaussian-tk)^2;
end

%% Solve using splines
% B-spline moment method: turn the (noisy, hard-thresholded) spline samples
% into power-sum moments, recover the centred delays with an annihilating
% filter, then shift them back by tau/2.
if solve_bspline==1
    if plot_flag
        disp('----------------------------------------------------------------------------------------------------------');
        disp('Using B-splines');
    end
    if digital_SNR<1e15
        sigma_spline = calc_sigma(c_n_spline,bspline_noise,digital_SNR);
        sigma_spline_average(i_SNR) = sigma_spline_average(i_SNR) + 1/num_of_experiments*sigma_spline;
        c_n_spline_noisy = c_n_spline + sigma_spline*bspline_noise;
        % hard thresholding: zero samples at or below bspline_threshold*sigma_spline
        c_n_spline_noisy = c_n_spline_noisy.*(abs(c_n_spline_noisy)>bspline_threshold*sigma_spline);
    else
        c_n_spline_noisy = c_n_spline;
    end
    % moments(p+1) = sum_n spline_coefficients(n,p+1)*c_n_spline_noisy(n), p = 0..spline_order
    moments = spline_coefficients.'*c_n_spline_noisy;
    %% Solve b-splines via matrix inversion
    % Scale moment p by T^p. The samples were generated at (tk - tau/2)/T, so
    % (assuming CalcSplineCoeff returns polynomial-reproduction coefficients -- TODO confirm)
    % moments_calibrated(p+1) ~ sum_k xk(k)*(tk(k) - tau/2)^p, i.e. CENTRED delays.
    moments_calibrated = moments.*T.^[0:spline_order].';
    % K x K Toeplitz system for the annihilating-filter coefficients
    r = moments_calibrated( ((K-1):-1:0) + 1);
    c = moments_calibrated( ((K-1):(spline_order-1)) + 1);
    moments_matrix = toeplitz(c,r);
    ann_coeff = -pinv(moments_matrix)*moments_calibrated((K:spline_order)+1);
    my_roots = roots([1;ann_coeff]);
    % Order the roots by real part. sort() on a complex vector orders by magnitude,
    % which would mis-pair the centred estimates with tk (e.g. -0.25, 0, 0.25)
    % as soon as noise makes any root complex.
    [dummy, root_order] = sort(real(my_roots));
    tk_spline_direct = my_roots(root_order) + tau/2;
    error_spline_direct_global(i_SNR) = error_spline_direct_global(i_SNR) + 1/num_of_experiments*norm(tk_spline_direct - tk)^2;
    % Amplitudes come from the Vandermonde system built on the centred delays,
    % the same coordinates in which the moments are expressed.
    xk_spline_direct = pinv(fliplr(vander(tk_spline_direct - tau/2)).')*moments_calibrated(1:K);
    if plot_flag
        disp(sprintf('Error in amplitudes, xk, Using estimated tk_spline_direct = %0.3g', norm(xk_spline_direct - xk)^2))
    end
    
%% Solve spline using FRI_classic function.
    % NOTE(review): spline_order is initialized to 2*K-1, so this condition is
    % never true with the current setup, and error_spline_global /
    % error_Cadzow_spline_global stay at zero.
    if spline_order>=2*K
    [dummy,ann_filter_coeff] = FRI_classic(moments_calibrated,K,tau);
    tk_spline = sort(real(roots(ann_filter_coeff)))+tau/2;

    error_spline_global(i_SNR) = error_spline_global(i_SNR) + 1/num_of_experiments*norm(tk_spline - tk)^2;
    % Show results for reconstruction from moments
    if plot_flag
        disp('----------------------------------------------------------------------------------------------------------');
        disp(sprintf('Annihilating Filter - Squared error = %0.3g', norm(tk_spline - tk)^2))
    end
    [dummy,ann_filter_coeff_Cadzow] = FRI_Cadzow(moments_calibrated,B,T,K,tau);
    tk_spline_Cadzow = sort(real(roots(ann_filter_coeff_Cadzow)))+tau/2;

    error_Cadzow_spline_global(i_SNR) = error_Cadzow_spline_global(i_SNR) + 1/num_of_experiments*norm(tk_spline_Cadzow - tk)^2;
    if plot_flag
        disp(sprintf('Cadzow FRI - Squared error = %0.3g', norm(tk_spline_Cadzow - tk)^2))
    disp('------------------------------------------');
    end
    
    % % Estimate amplitudes xk (Vandermonde on the centred delays, as above)
    xk_spline = pinv(fliplr(vander(tk_spline - tau/2)).')*moments_calibrated(1:K);
    if plot_flag
        disp(sprintf('Error in amplitudes, xk, Using estimated tk_spline = %0.3g', norm(xk_spline - xk)^2))
    end
    end
end
%% Solve E-splines
% E-spline moment method: noisy samples -> moments via Espline_coeff_matrix ->
% annihilating-filter roots -> delays from the root phases.
% Unlike the B-spline path, no hard thresholding is applied here.
if solve_espline==1
    if plot_flag
        disp('----------------------------------------------------------------------------------------------------------');
        disp('Using E-splines');
    end
    if digital_SNR<1e15
        sigma_Espline = calc_sigma(c_n_Espline,Espline_noise,digital_SNR);
        sigma_Espline_average(i_SNR) = sigma_Espline_average(i_SNR) + 1/num_of_experiments*sigma_Espline;
        c_n_Espline_noisy = c_n_Espline + sigma_Espline*Espline_noise;
    else
        % SNR >= 150 dB: treated as noiseless
        c_n_Espline_noisy = c_n_Espline;
    end
    % one row of Espline_coeff_matrix per moment
    moments_Espline = Espline_coeff_matrix*c_n_Espline_noisy;
%% Solve e-splines via matrix inversion
    % No T^p scaling here (the B-spline-style calibration is commented out):
    % the time scale is applied later through the root phases.
    moments_Espline_calibrated = moments_Espline;%.*T.^[0:Espline_order].';
    % K x K Toeplitz system for the annihilating-filter coefficients
    r = moments_Espline_calibrated( ((K-1):-1:0) + 1);
    c = moments_Espline_calibrated( ((K-1):(Espline_order-1)) + 1);
    moments_Espline_matrix = toeplitz(c,r);
    ann_coeff_Espline = -pinv(moments_Espline_matrix)*moments_Espline_calibrated((K:Espline_order)+1);
    my_roots_Espline = roots([1;ann_coeff_Espline]);
    % Spacing of the E-spline exponents; only the first difference is used, so
    % alpha_vec is presumably uniformly spaced -- TODO confirm.
    alpha_diff_vec = diff(alpha_vec);
    alpha_diff = alpha_diff_vec(1);
    % Delay = root phase / exponent spacing (times T), shifted back by tau/2 and
    % wrapped into [0, tau). NOTE(review): the 1j factor assumes alpha_diff is
    % purely imaginary so that the result is real (mod needs real input) -- confirm.
    tk_Espline = sort(mod(1j*angle(my_roots_Espline)/(alpha_diff)*T + tau/2,tau));

    error_Espline_direct_global(i_SNR) = error_Espline_direct_global(i_SNR) + 1/num_of_experiments*norm(tk_Espline - tk)^2;
end

end % SNR loop
end % experiment loop
delete(h)
%% Display global results
disp('----------------------------------------------------------------------------------------------------------');
disp('Global results')
% One line per (method, SNR value); every label expands to the same format
% string an individual fprintf call would use.
report_labels = {'SoS 3 period filter', 'SoS 3 period filter with Cadzow', ...
    'sinc filter', 'sinc filter with Cadzow', ...
    'gaussian filter', 'gaussian filter with Cadzow', ...
    'spline filter', 'spline filter with Cadzow', ...
    'spline filter - Direct matrix inversion', 'E-spline filter - Direct matrix inversion'};
report_errors = {error_3P_global, error_Cadzow_3P_global, ...
    error_sinc_global, error_Cadzow_sinc_global, ...
    error_gaussian_global, error_Cadzow_gaussian_global, ...
    error_spline_global, error_Cadzow_spline_global, ...
    error_spline_direct_global, error_Espline_direct_global};
for i_report = 1:numel(report_labels)
    fprintf(['Using ' report_labels{i_report} '. Average times error = %0.3g \n'], report_errors{i_report}/tau)
end
clear report_labels report_errors i_report % leave the workspace (and any saved file) as before
fprintf('Average sigma_noise of B-spline method = %0.3g\n',sigma_spline_average)
fprintf('Minimal spacing between deltas = %0.3g\n',min(diff(tk)))
%% Save experiment data to file, if "save2file_flag" is on
% Saves the entire workspace; the file name encodes K and the number of experiments.
if save2file_flag && solve_espline
    save(strcat('Aperiodic_Noisy_WithEspline_K',int2str(K),'_',int2str(num_of_experiments),'experiments'))
elseif save2file_flag
    save(strcat('Periodic_Noisy_K',int2str(K),'_',int2str(num_of_experiments),'experiments'))
end

%% Graphs
FontSize = 8;
LineWidth = 1.5;


if length(SNR_scale)>1
    % Error-vs-SNR curves (drawn only when more than one SNR value was run)
    figure('Name','semilogy(error): SoS vs. previous methods');
    semilogy(SNR_scale,error_3P_global/tau,'LineWidth',LineWidth)
    PlotLabelLog('SoS filter',SNR_scale,error_3P_global/tau,9,'right',FontSize);
    if solve_bspline % B-spline and Gaussian results are computed only when "solve_bspline" is on
        hold all
        semilogy(SNR_scale,error_spline_direct_global/tau,'LineWidth',LineWidth)
        PlotLabelLog('Spline filter',SNR_scale,error_spline_direct_global/tau,9,'right,long',FontSize);
        semilogy(SNR_scale,error_gaussian_global/tau,'LineWidth',LineWidth)
        PlotLabelLog('Gaussian filter',SNR_scale,error_gaussian_global/tau,9,'left,long',FontSize);
    end
    if solve_espline % E-spline results are computed only when "solve_espline" is on
        hold all
        semilogy(SNR_scale,error_Espline_direct_global/tau,'LineWidth',LineWidth)
        PlotLabelLog('E-spline filter',SNR_scale,error_Espline_direct_global/tau,9,'left',FontSize);
    end
    xlabel('SNR [dB]');
    ylabel('Time-delay estimation error [units of \tau]');
    set(gca,'FontSize',FontSize);
    
    if plot_flag
        % Debug figure: E-spline vs. SoS only
        figure('Name','semilogy(error): E-spline vs. SoS');
        semilogy(SNR_scale,error_3P_global/tau,'LineWidth',LineWidth)
        PlotLabelLog('SoS filter',SNR_scale,error_3P_global/tau,9,'top',FontSize);
        if solve_espline % show E-spline only if "solve_espline" flag is on
            hold all
            semilogy(SNR_scale,error_Espline_direct_global/tau,'LineWidth',LineWidth)
            PlotLabelLog('E-spline filter',SNR_scale,error_Espline_direct_global/tau,9,'right,long',FontSize);
        end
        xlabel('SNR [dB]');
        ylabel('Time-delay estimation error [units of \tau]');
        set(gca,'FontSize',FontSize);
        
    end

end

% True deltas vs. the SoS-filter estimate from the last SNR value of the last
% completed experiment.
figure('Name','Estimation vs. original - SoS filter');
stem(tk,xk,'LineWidth',LineWidth,'MarkerEdgeColor','None');
% tk_FRI_3P does not exist if the run was cancelled before the first experiment
has_estimate = exist('tk_FRI_3P','var');
if has_estimate
    hold all
    % a single LineSpec ('--o') rather than two separate style arguments
    stem(tk_FRI_3P,xk_FRI_3P,'--o','LineWidth',0.01,'MarkerEdgeColor','k','MarkerFaceColor','g','MarkerSize',8)
end
ax = axis;
ax(1:2) = [0 tau];
axis(ax);
xlabel('time [units of \tau]')
ylabel('amplitude')
if has_estimate
    legend('Original','Estimated')
else
    legend('Original')
end