% This script runs BOMP analysis on test data. It plots the retrieved
% frequencies according to the retrieved blocks, all retrieved amplitudes
% of the vectors in those blocks (including the amplitudes of the diagonal
% and overtone peaks) and the retrieved cross peak magnitudes. Finally, it
% reconstructs a clean version of the full spectrum with high resolution,
% as well as a version with the cross peaks only.

% Start from a clean session: empty workspace, no open figures and a
% cleared command window.
clear('all');
close('all');
clc;

% Start the stopwatch; the elapsed time is reported by the toc at the end
% of the script.
tic;

% Load the saved inputs. The script relies on these files providing the
% variables testData (the measured 2D time-domain matrix) and TVec (the
% delay axis).
load('testData');
load('TVec');
timeDomainDataMat = testData;

% Constants
% Unit scale factors. FSEC = 1, so times appear to be expressed in
% femtoseconds; CM = 1e-2 and NM = 1e-9 express lengths in meters.
PSEC = 1e3;
FSEC = 1;
CM = 1e-2;
NM = 1e-9;
% Presumably the speed of light in m/fs (3e8 m/s * 1e-15 s/fs) - TODO confirm.
C = 3e-7;

GAMMA = C*(5.5*(CM)^-1); % Vibrational level homogeneous width, CAN BE AN ARRAY OF VALUES
LAMBDA_0 = 795*NM;
LAMBDA_FWHM = 3.75*NM;
T_FWHM = 3.5*PSEC; % Gaussian convolved with line. Includes effects of inhomogeneous broadening AND convolving with pulse spectrum.
                   % The convolution is largely counteracted by the multiplication
% Passed to the dictionary generators below; presumably the pulse
% duration - TODO confirm against gendictionary.
Tp = 25*FSEC;

% Block layout of the dictionary handed to BOMP.
% BLK_SPR is passed to BOMP as its last argument; presumably the number of
% blocks BOMP is allowed to select - TODO confirm against BOMP.
BLK_SPR = 3;
% Number of columns each block takes from the in-phase, the pi/2-shifted
% and the diagonal dictionary, respectively.
MAX_OVERTONE = 8;
MAX_OVERTONE_PHI = 8;
MAX_OVERTONE_DIAG = 4;
% Total number of dictionary columns per block.
BLK_SIZE = MAX_OVERTONE + MAX_OVERTONE_PHI + MAX_OVERTONE_DIAG;
% Only referenced by the commented-out smoothn call further down.
SMOOTH_PARAM = 400;

% Quantities derived from the delay axis TVec
numPts = numel(TVec);
dT     = TVec(2) - TVec(1);       % sampling step of the delay axis
TSpan  = (numPts - 1)*dT;         % total span covered by TVec
dF     = 1/TSpan;                 % frequency step matching that span
FMax   = numPts/2*dF;

% Two-sided frequency axis: mirrored negative half followed by the
% positive half without repeating the zero entry.
FVec = 0:dF:FMax;
FVec = [-FVec(end:-1:1), FVec(2:end)];

% Delay grids along the two axes of the data matrix (only needed by the
% optional artifact-removal step).
[TMat1,TMat2] = meshgrid(TVec);

% Gaussian multiplied with line. This is multiplication of the line with
% the pulse spectrum squared.
FFwhm = sqrt(2)*C*LAMBDA_FWHM/LAMBDA_0^2;

% Stack the 2D data matrix column by column into a single column vector.
timeDomainData = reshape(timeDomainDataMat, [numPts^2, 1]);
% smTimeDomainDataMat = smoothn(timeDomainDataMat,SMOOTH_PARAM);  % For removing a slowly varying component,
                                                                % usually due to slow drifts in the laser power.
                                                                % If necessary, use with cancelartifactmatrix

% Build the dictionary from three parts:
%   1. dictionary     - terms that oscillate along both T1 and T2, where T1
%                       runs along the x-axis of the data matrix and T2
%                       along the y-axis.
%   2. dictionaryPhi  - the same terms, but with the T2 oscillation pi/2
%                       out of phase with the T1 oscillation.
%   3. dictionaryDiag - terms that oscillate along the T1 - T2 direction
%                       (and the T1 + T2 direction, if necessary).
[dictionary, FVecBOMP] = gendictionary(TVec, GAMMA, FFwhm, Tp, T_FWHM, MAX_OVERTONE, 0);
dictionaryPhi = gendictionary(TVec, GAMMA, FFwhm, Tp, T_FWHM, MAX_OVERTONE_PHI, pi/2);
dictionaryDiag = gendiagonaldictionary(TVec, GAMMA, FFwhm, Tp, T_FWHM, MAX_OVERTONE_DIAG);

% Interleave the three dictionaries block by block. Inside every block of
% BLK_SIZE columns the order is: MAX_OVERTONE in-phase columns, then
% MAX_OVERTONE_DIAG diagonal columns, then the pi/2-shifted columns.
numColsTot = size(dictionary,2) + size(dictionaryDiag,2) + size(dictionaryPhi,2);
dictionaryTot = zeros(size(dictionary,1), numColsTot);

% Running read positions in the three source dictionaries
count = 0;
countPhi = 0;
countDiag = 0;

firstDiagSlot = MAX_OVERTONE;                     % zero-based slot where the diagonal part starts
firstPhiSlot = MAX_OVERTONE + MAX_OVERTONE_DIAG;  % zero-based slot where the pi/2 part starts

for col = 1:numColsTot
    slot = rem(col-1, BLK_SIZE);   % zero-based position of this column inside its block
    if slot >= firstPhiSlot
        countPhi = countPhi + 1;
        dictionaryTot(:,col) = dictionaryPhi(:,countPhi);
    elseif slot >= firstDiagSlot
        countDiag = countDiag + 1;
        dictionaryTot(:,col) = dictionaryDiag(:,countDiag);
    else
        count = count + 1;
        dictionaryTot(:,col) = dictionary(:,count);
    end
end
disp('Finished creating dictionary, starting artifact removal');

% Artifact removal - may be used with experimental data with known
% artifacts. As currently implemented it takes 2-3 minutes on a regular PC.
% clDict = cancelartifactmatrix(TMat1, TMat2, Tp, dictionaryTot, smTimeDomainDataMat);
% clTimeDomainData = cancelartifactmatrix(TMat1, TMat2, Tp, timeDomainData, smTimeDomainDataMat);

% With artifact removal disabled the "cleaned" quantities are simply the
% unmodified dictionary and data.
clDict = dictionaryTot;
clTimeDomainData = timeDomainData;

% Block selection: run BOMP on the block dictionary. The positions of the
% non-zero entries of x identify the dictionary columns that were chosen.
disp('Starting BOMP');
[x,r,normR,residHist,errHist] = BOMP(clDict, clTimeDomainData, BLK_SIZE, BLK_SPR);
chosenDictInd = find(x);

% Re-estimate the magnitudes of the chosen columns with a robust
% regression (Cauchy weights, tuning constant 2.385, no constant term).
chosenColumns = clDict(:,chosenDictInd);
[coeffs,stats] = robustfit(chosenColumns,clTimeDomainData,'cauchy',2.385,'off');

% Fit, residual and mean squared error of the strong components
timeDomainDataFit = chosenColumns*coeffs;
residual = clTimeDomainData - timeDomainDataFit;
mse = mean(residual.^2);
mseMessage = ['1000*MSE ',num2str(1000*mse)];
disp(mseMessage);

% Plot the retrieved coefficients of all selected frequencies. Entry k of
% every block is drawn as its own curve against the block frequency, which
% is converted with the factor 1/(C*100) used for all spectral axes in
% this script.
wavenumberAxis = FVecBOMP/(C*100);

% Line specifications for the 2nd, 3rd, ... curve of a figure. The first
% curve of each figure is drawn without a specification, before hold is
% switched on, so that it keeps the default axes behavior and color.
extraSpecs = {'r','g','m','k','c','bo','ro','go','mo','ko','co'};

% Figure 1: block entries 1-12, i.e. the in-phase columns followed by the
% diagonal columns of every block.
figure, plot(wavenumberAxis,x(1:BLK_SIZE:end));
for entry = 2:12
    hold on, plot(wavenumberAxis,x(entry:BLK_SIZE:end),extraSpecs{entry-1});
end

legend('Corr w/ DC','Corr 2w/ DC', 'Corr w/ w','Corr w/ 2w','Corr w/ 3w','Corr w/ w/2', ...
    'Corr 2w/ 2w', 'Corr 3w/ 3w','Diag Art','Corr diag w/ DC', 'Corr diag 2w/ DC', 'Corr diag 3w/ DC');
title('Retrieved amplitudes and frequencies of all strong componenets');

% Figure 2: block entries 13-20, i.e. the pi/2-shifted columns of every
% block.
figure, plot(wavenumberAxis,x(13:BLK_SIZE:end));
for entry = 14:20
    hold on, plot(wavenumberAxis,x(entry:BLK_SIZE:end),extraSpecs{entry-13});
end
legend('Corr w/ DC','Corr 2w/ DC', 'Corr w/ w','Corr w/ 2w','Corr w/ 3w','Corr w/ w/2', ...
    'Corr 2w/ 2w', 'Corr 3w/ 3w');
title('Retrieved amplitudes and frequencies of all strong componenets with \phi=\pi/2');

% Now that the strong components have been removed from the data,
% retrieve the cross-peak values from the residual.

% A block counts as selected when the first coefficient of the block is
% non-zero.
coeffsSprs = x;
blockCoeffs = coeffsSprs(1:BLK_SIZE:end);
binaryCoeffs = blockCoeffs~=0;

% Pick the frequencies of the selected blocks by logical indexing. Unlike
% nonzeros(FVecBOMP.*binaryCoeffs') this keeps a selected block whose
% frequency is exactly 0 (so the number of frequencies always matches the
% number of selected blocks) and does not depend on whether FVecBOMP is a
% row or a column. The result is forced to a column vector, which is what
% nonzeros returned.
selectedFs = reshape(FVecBOMP(binaryCoeffs), [], 1);

% Fit the cross terms of all pairs of selected frequencies to the residual
% with the same robust regression settings as for the strong components
% (Cauchy weights, tuning constant 2.385, no constant term).
CrossTerms = gencrossterms(TVec, GAMMA, FFwhm, T_FWHM, selectedFs, 0, pi);
[Ccoeffs,Cstats] = robustfit(CrossTerms,residual,'cauchy',2.385,'off');
CTimeDomainDataFit = CrossTerms*Ccoeffs;

% Calculate final residual
CResidual = residual - CTimeDomainDataFit;
Cmse = mean(CResidual.^2);
disp(['1000*CMSE ',num2str(1000*Cmse)]);

% Plot retrieved cross-peak magnitudes. One tick label per pair of lines,
% in the order (1,2), (1,3), ..., (2,3), ... For three selected lines this
% yields exactly 'line 1 & line 2', 'line 1 & line 3', 'line 2 & line 3'.
% NOTE(review): assumes gencrossterms orders its columns the same way for
% any number of lines - confirm against gencrossterms.
numLines = numel(selectedFs);
numPairs = numLines*(numLines - 1)/2;
pairLabels = cell(1, numPairs);
pairIdx = 0;
for iLine = 1:numLines
    for jLine = iLine+1:numLines
        pairIdx = pairIdx + 1;
        pairLabels{pairIdx} = sprintf('line %d & line %d', iLine, jLine);
    end
end

figure, plot(abs(Ccoeffs));
set(gca, 'XTick', 1:numPairs,'XTicklabel',pairLabels);
title('Cross-peak magnitudes');

%% Plot a "clean" 2D spectrum. This section can be run separately from the
% rest of the script, provided that coeffs, Ccoeffs and selectedFs were
% saved in the folder of this script.

% Uncomment for running this part individually
% clear all;
%
% load coeffs;
% load Ccoeffs;
% load selectedFs;
% load TVec;

% Constants repeated here so that the section is self-contained
PSEC = 1e3;
CM   = 1e-2;
C    = 3e-7;

MAX_OVERTONE      = 8;
MAX_OVERTONE_PHI  = 8;
MAX_OVERTONE_DIAG = 4;
BLK_SIZE = MAX_OVERTONE + MAX_OVERTONE_PHI + MAX_OVERTONE_DIAG;

GAMMA = C*(5.5*(CM)^-1);
% NOTE(review): this overrides the FFwhm computed from the pulse spectrum
% in the first part of the script with a width based on 3.5*PSEC (the
% value of T_FWHM there) - confirm that this is intended.
FFwhm = sqrt(2)/(3.5*PSEC);

% Display axes with very high resolution: the time axis keeps the original
% sampling step but ends at resFactor times the original end point, which
% makes the matching frequency step correspondingly finer.
resFactor = 10;
dTOrg     = TVec(2) - TVec(1);
TVecDisp  = 0:dTOrg:(resFactor*TVec(end) - dTOrg);

numPtsDisp = numel(TVecDisp);
dTDisp     = TVecDisp(2) - TVecDisp(1);
TSpanDisp  = (numPtsDisp - 1)*dTDisp;
dFDisp     = 1/TSpanDisp;
FMaxDisp   = numPtsDisp/2*dFDisp;

% Two-sided frequency axis: mirrored negative half (without the zero
% entry) followed by the full non-negative half.
FVecDisp = 0:dFDisp:FMaxDisp;
FVecDisp = [-FVecDisp(end:-1:2), FVecDisp];

% Reconstruct "strong" peaks - diagonal peaks and correlation peaks between
% the fundamental molecular frequency and the first overtone.
%
% For every selected frequency a normalized line shape is computed on the
% display grid; the 2D contribution is the outer product of that line
% shape with itself (diagonal peak) or with the overtone line shape
% (overtone peak). Each 2D matrix is stored as one column of a vector bank
% so the peaks can be summed with the fitted coefficients.
numDisp = length(TVecDisp);
numSel = length(selectedFs);

vectorBank = zeros(numDisp^2,numSel);
vectorBankOv = zeros(numDisp^2,numSel);
lineShape = zeros(numDisp,numSel);
lineShapeOv = zeros(numDisp,numSel);

% The Gaussian envelope does not depend on the selected frequency, so it
% is computed once instead of once per loop iteration.
GaussianF = exp(-2*(FVecDisp/(FFwhm/sqrt(log(2)))).^2);

for i=1:numSel
    % Fundamental peaks (i.e. correlation of w with w)
    inhomogLineShape = -1i*(1./(FVecDisp.^2 - selectedFs(i)^2 - 2*1i*FVecDisp*GAMMA));
    lineShapeNotNorm = fft_HF(ifft_HF(real(inhomogLineShape)).*ifft_HF(GaussianF))'; % Without homogeneous broadening
    lineShape(:,i) = lineShapeNotNorm./sqrt(sum(abs(lineShapeNotNorm).^2)); % unit Euclidean norm

    % Outer product of the line shape with itself. bsxfun expands the row
    % and the column on the fly and yields exactly the same values as
    % repmat(row,N,1).*repmat(col,1,N), without first building two full
    % N-by-N copies.
    % NOTE(review): ' is the complex-conjugate transpose, so the row factor
    % is conjugated if lineShape is complex - confirm that this is intended.
    FIDMat = bsxfun(@times, lineShape(:,i)', lineShape(:,i));
    vectorBank(:,i) = reshape(FIDMat,numDisp^2,1);

    % Overtone peaks (i.e. correlation of w with 2w)
    if i == 1 % Only use the first overtone of 219 because others fold
        % Correlation of fundamental with first overtone
        inhomogLineShapeOv = -1i*(1./(FVecDisp.^2 - (2*selectedFs(i))^2 - 2*1i*FVecDisp*GAMMA));
        lineShapeOvNotNorm = fft_HF(ifft_HF(real(inhomogLineShapeOv)).*ifft_HF(GaussianF))'; % Without homogeneous broadening
        lineShapeOv(:,i) = lineShapeOvNotNorm./sqrt(sum(abs(lineShapeOvNotNorm).^2));

        % Symmetrized product: overtone along one axis and fundamental
        % along the other, plus the same with the axes exchanged.
        FIDMatOv = bsxfun(@times, lineShapeOv(:,i)', lineShape(:,i)) + ...
            bsxfun(@times, lineShape(:,i)', lineShapeOv(:,i));
        vectorBankOv(:,i) = reshape(FIDMatOv,numDisp^2,1);
    end
end

% Weight the peaks with the fitted coefficients. Entries 3 and 4 of every
% block of coeffs are used for the diagonal and the overtone peak,
% respectively; this relies on coeffs holding BLK_SIZE values per selected
% block.
diagPeakReconF = vectorBank*(coeffs(3:BLK_SIZE:end));
ovPeakReconF = vectorBankOv*(coeffs(4:BLK_SIZE:end));
diagPeakReconFMat = reshape(diagPeakReconF,numDisp,numDisp);
ovPeakReconFMat = reshape(ovPeakReconF,numDisp,numDisp);

% Reconstruct cross peaks: one vector-bank column per pair (i,j), i < j, of
% selected frequencies, in the order (1,2), (1,3), ..., (2,3), ...
numDisp = length(TVecDisp);
numSel = length(selectedFs);
CvectorBank = zeros(numDisp^2,(numSel^2-numSel)/2);

% Phase applied to one of the two line shapes of every cross term
PHI = pi;
cosPhi = cos(PHI);
sinPhi = sin(PHI);

count = 0;
for i = 1:numSel
    % Line shape i with its real and imaginary parts rotated by PHI
    % (algebraically lineShape(:,i)*exp(-1i*PHI)). It depends only on i,
    % so it is computed once per outer iteration.
    lineShapePhii = (cosPhi*real(lineShape(:,i))+sinPhi*imag(lineShape(:,i))) + ...
        1i*((-sinPhi*real(lineShape(:,i))+cosPhi*imag(lineShape(:,i))));
    for j = i+1:numSel % The j frequency is always the larger one
        count = count + 1;
        lineShapePhij = (cosPhi*real(lineShape(:,j))+sinPhi*imag(lineShape(:,j))) + ...
            1i*((-sinPhi*real(lineShape(:,j))+cosPhi*imag(lineShape(:,j))));

        % Symmetrized outer products. bsxfun yields exactly the same values
        % as repmat(row,N,1).*repmat(col,1,N) without first building two
        % full N-by-N copies per product.
        % NOTE(review): ' is the complex-conjugate transpose, so the row
        % factors are conjugated if lineShape is complex - confirm that
        % this is intended.
        CFIDMat = bsxfun(@times, lineShape(:,i)', lineShapePhij) + ...
            bsxfun(@times, lineShape(:,j)', lineShapePhii);
        CvectorBank(:,count) = reshape(CFIDMat,numDisp^2,1);
    end
end

% Weight the cross peaks with the fitted cross-peak coefficients and add
% them to the strong peaks to obtain the full reconstructed spectrum.
crossPeakReconF = CvectorBank*(Ccoeffs);
crossPeakReconFMat = reshape(crossPeakReconF,numDisp,numDisp);
ReconFMat = diagPeakReconFMat + ovPeakReconFMat + crossPeakReconFMat;

% Contour plots of the reconstructed spectra: first the fully
% reconstructed signal, then the cross-peak component only. Both figures
% share the same axes, limits, ticks and labels and differ only in the
% plotted matrix and the title.
wavenumberAxisDisp = FVecDisp/(C*100);
axisLimits = [100 550];
tickPositions = 100:50:550;
tickLabels = {'100','','200','','300','','400','','500',''}; % label every second tick

spectra = {ReconFMat, crossPeakReconFMat};
spectraTitles = {'Fully reconstructed spectrum with 10x resolution', ...
    'Cross peaks only with 10x resolution'};

for figIdx = 1:2
    figure,contour(wavenumberAxisDisp,wavenumberAxisDisp,abs(spectra{figIdx}));
    xlim(axisLimits);
    ylim(axisLimits);
    set(gca, 'XTick', tickPositions,'XTicklabel',tickLabels);
    set(gca, 'YTick', tickPositions,'YTicklabel',tickLabels);
    title(spectraTitles{figIdx});
    xlabel('\omega_1 [cm^{-1}]'); ylabel('\omega_2 [cm^{-1}]');
end

% Report the time elapsed since the tic at the start of the script
toc