// This file (me.sci) defines the CCA toolbox functions that compute the
// Moreau envelope, several of them through the Legendre-Fenchel transform.
// It uses lft.sci for the conjugate computations.
// See the help files for more information on the functions.

///////////////////// DIRECT COMPUTATION ///////////BEGIN
//Moreau Envelope direct computation using matrix arithmetic (quadratic cost)
//for comparison purposes only
function [M,p,P] = me_direct(X, f, S) 
	// Brute-force Moreau envelope of the samples (X(i), f(i)) evaluated at S(j):
	//   M(j) = min_i (X(i) - S(j))^2 + f(i)
	// All vectors are column vectors (n = size(X,1), m = size(S,1));
	// requires O(n*m) time and storage.
	// Outputs:
	//   M - envelope values at S
	//   p - for each S(j), one minimizing index i (as returned by min)
	//   P - n x m matrix: column j lists every minimizing index in its first rows and
	//       repeats the first minimizer in the remaining rows. P is the third output
	//       and is only computed when the caller asks for it.
	t1 = X * ones(1,size(S,1));//size n x m, t1(i,j) = X(i)
	t2 = ones(size(X,1),1) * S';//size n x m, t2(i,j) = S(j)
	t3 = (t1 - t2).^2;//(X(i) - S(j))^2
	
	t4 = f * ones(1,size(S,1));//size n x m, t4(i,j) = f(i)
	t=t3+t4;//(X(i) - S(j))^2 + f(i)
	[ M, p ] = min( t, 'r')';//only gives a selection in P not all argmin
	M=M';

	[outargn ,inargn] = argn();
	if (outargn > 2) //P is the 3rd output: skip the full prox mapping unless it is requested
		n=size(X,1);m=size(S,1);
		P = zeros(n,m);
		for j=1:1:m,
			tmp = find(t(:,j)==M(j),-1)';//find the full argmin
			P(:,j) = tmp(1);//fill all rows of the column with the first minimizer
			P(1:size(tmp,1),j) = tmp;//case argmin multivalued, fill first rows
		end;
	end;
endfunction

//assume n,m positive integer, and test whether sizef=[n,m]
//internal function
function [b,s]=checksize(n,m,sizef,functionname)
	// Check that a sample matrix f of size sizef matches grids of lengths n (rows)
	// and m (columns).
	// Returns b = %t when sizef = [n,m]. Otherwise b = %f and s holds an error message
	// naming functionname; when both dimensions mismatch, s reports both of them
	// (separated by "; ") instead of only the last one.
	b = n==sizef(1) & m==sizef(2);s='';
	if ~b
		if n~=sizef(1)
			s=sprintf("Unconsistent size of f in %s. size(Xr)=%i, but size(f,1)=%i",functionname,n,sizef(1));
		end;
		if m~=sizef(2)
			s2=sprintf("Unconsistent size of f in %s. size(Xc)=%i but size(f,2)=%i",functionname,m,sizef(2));
			if s == ''
				s = s2;
			else
				s = s + '; ' + s2;
			end;
		end;
	end
endfunction

//Moreau envelope of a bivariate function
//for comparison purposes only
function M = me_direct2d (Xr, Xc, f, Sr, Sc)
	// f(i,j) is the value at (Xr(i), Xc(j)); the result is evaluated on Sr x Sc.
	// The squared Euclidean norm is separable, so the 2-D minimization is done as two
	// passes of the 1-D brute-force me_direct (quadratic cost each). This gives
	// O(n^3) = O(N^(3/2)) work instead of O(n^4), where n = length(Xr) = length(Xc) =
	// length(Sr) = length(Sc) and N = n^2.
	[ok, msg] = checksize(length(Xr), length(Xc), size(f), 'me_direct2d');
	if ~ok then
		error(msg);
	end
	// pass 1: envelope along the second variable, one row of f at a time
	partial = ones(size(Xr, 1), size(Sc, 1));
	for row = 1:length(Xr)
		envrow = me_direct(Xc, f(row, :)', Sc);
		partial(row, :) = envrow';
	end
	// pass 2: envelope along the first variable, one column of the partial result at a time
	M = ones(size(Sr, 1), size(Sc, 1));
	for col = 1:size(Sc, 1)
		M(:, col) = me_direct(Xr, partial(:, col), Sr);
	end
endfunction
///////////////////// DIRECT COMPUTATION ///////////END

///////////////////// BRUTE COMPUTATION ///////////BEGIN
//for comparison purposes only
function M = me_brute2d (Xr, Xc, f, Sr, Sc)
	// Brute-force Moreau envelope of f, sampled on Xr x Xc (f(q1,q2) at (Xr(q1), Xc(q2))),
	// evaluated on Sr x Sc:
	//   M(p1,p2) = min_{q1,q2} (Xr(q1)-Sr(p1))^2 + (Xc(q2)-Sc(p2))^2 + f(q1,q2)
	// Cost is O(n^4) for n = length(Xr) = length(Xc) = length(Sr) = length(Sc), i.e.
	// quadratic in the number N = n^2 of grid points.
	// Xr and Xc are assumed to be column vectors (the outer products below need it).
	[b,s]=checksize(length(Xr),length(Xc),size(f),'me_brute2d');
	if ~b, error(s);end;
	[n1,n2]=size(f);m1=length(Sr);m2=length(Sc);
	M=zeros(m1,m2);
	for p1=1:m1
		// (Xr(q1)-Sr(p1))^2 repeated in every column; it only depends on p1, so it is
		// computed once per p1 rather than once per (p1,p2)
		t1 = (Xr-Sr(p1)).^2 * ones(1,n2);
		for p2=1:m2
			// (Xc(q2)-Sc(p2))^2 repeated in every row, size n1 x n2
			t2 = ones(n1,1) * ((Xc-Sc(p2)).^2)';
			// min of a matrix is the minimum over all its entries: O(n^2) per (p1,p2)
			M(p1,p2) = min(t1 + t2 + f);
		end;
	end;
endfunction
///////////////////// BRUTE COMPUTATION ///////////END


///////////////////// NONEXPANSIVE PROXIMAL MAPPING ///////////BEGIN
// Moreau Envelope computation using monotone nonexpansive proximal mapping (linear-time)
// only works for CONVEX functions
//   M(j) = min_i (X(i) - S(j))^2 + f(i),   P(j) = index i retained for S(j)
// NOTE(review): presumably X and S are also expected sorted in increasing order, and
// consecutive S(j) close enough that the minimizing index moves by at most one
// sample per step (step 2 only looks at i and i+1) -- TODO confirm with the callers.
function [M,P] = me_nep(X, f, S)
	// step 1. calculate the MYE for S(1) (binary search)
	// (more precisely a bracketing search that shrinks [a,b] around the minimizer of v)
	min_val = %inf; // NOTE(review): never used below
	i = 1; // overwritten by the result of the search
	m = size(S, 1);n=size(X,1);
	// v(i): objective of the minimization for S(1) evaluated at sample i
	function phi = v(i), phi = (S(1)-X(i))^2 + f(i), endfunction;
	//v = (S(1) - X) .^ 2 + f;
	a = 1; b = n;
	while (b - a > 1)
		// probe the middle of [a,b] and the middles of its two halves
		mid = round((b + a) / 2);
		l = round((a + mid) / 2);
		r = round((mid + b) / 2);
		
		if (v(mid) < min(v(l), v(r)))
			// the middle beats both probes: keep [l,r]
			a = l; b = r;
		elseif (min(v(l), v(r)) < min(v(a), v(b)))
			// a probe beats both ends: keep the half on the side of the smaller probe
			// NOTE(review): if v(l) == v(r) here, a and b are unchanged and the loop would
			// not progress; presumably unreachable for convex f -- confirm
			if (v(l) < v(r))
				b = mid;
			elseif (v(l) > v(r))
				a = mid;
			end
		else
			// the ends are no worse than the probes: drop the quarter next to the larger end
			if (v(a) < v(b))
				b = l;
			else
				a = r;
			end
		end
	end
	
	// keep the smaller of the two remaining candidates (ties go to b)
	if (v(b) > v(a))
		i = a;
	else
		i = b;
	end
	// M and P are grown by insertion (not preallocated)
	M(1) = v(i);P(1)=i;
	// step 2. iterate through the rest of f and build M based on increasing indices
	// Only samples i and i+1 are compared for each S(j); a tie selects i+1.
	for j=2:m
		val_1 = (X(i) - S(j)) ^ 2 + f(i);
		if (i < n)
			val_2 = (X(i+1) - S(j)) ^ 2 + f(i+1);
		else
			val_2 = %inf;
		end
		
		if (val_1 < val_2)
			M(j) = val_1;P(j)=i;
		else
			// NOTE(review): when i == n and f(n) == %inf, val_1 == val_2 == %inf and
			// P(j) is set to n+1, an out-of-range index
			M(j) = val_2;P(j)=i+1;
			if (i < n),i = i + 1;end;
		end
	end
endfunction

//Moreau envelope of a bivariate convex function
function M = me_nep2d(Xr, Xc, f, Sr, Sc)
	// f(i,j) is the value at (Xr(i), Xc(j)); the result is evaluated on Sr x Sc:
	//   M(p1,p2) = min_{q1,q2} (Sr(p1)-Xr(q1))^2 + (Sc(p2)-Xc(q2))^2 + f(q1,q2)
	// computed with two passes of me_nep (separability of the squared Euclidean norm).
	// The output is size(Sr,1) x size(Sc,1), the same layout as me_direct2d/me_llt2d.
	// Row i of f varies along Xc and is therefore processed with (Xc, Sc). Each column of
	// the partial result varies along Xr and is processed with (Xr, Sr). The previous
	// version paired them the other way round, which was only correct when Xr == Xc and
	// Sr == Sc.
	[b,s]=checksize(length(Xr),length(Xc),size(f),'me_nep2d');
	if ~b, error(s);end;
	nr = size(f, 1);
	nsc = size(Sc, 1);
	// pass 1: envelope along the second variable, one row of f at a time
	rowpass = zeros(nr, nsc);
	for i = 1:nr
		env = me_nep(Xc, f(i,:)', Sc);
		rowpass(i,:) = env(:)';// env(:) makes the result independent of me_nep's orientation
	end
	// pass 2: envelope along the first variable, one column of rowpass at a time
	colpass = zeros(size(Sr, 1), nsc);
	for j = 1:nsc
		env = me_nep(Xr, rowpass(:,j), Sr);
		colpass(:,j) = env(:);
	end
	M = colpass;
endfunction
///////////////////// NONEXPANSIVE PROXIMAL MAPPING ///////////END

///////////////////// DISCRETE LEGENDRE TRANSFORM using LLT ///////////BEGIN
//Main function: Moreau envelope of univariate function (not necessarily convex)
function M = me_llt(X,f,S,fusionopt)
	// Evaluates M = S.^2 - 2*Conj, where Conj is lft_llt applied to g = (X.^2 + f)/2.
	// Assuming lft_llt returns the discrete conjugate Conj(j) = max_i S(j)*X(i) - g(i),
	// this is M(j) = min_i (X(i) - S(j))^2 + f(i).
	// fusionopt is forwarded to lft_llt; it defaults to 1 when the caller omits it.
	[nlhs, nrhs] = argn();
	if nrhs < 4 then
		fusionopt = 1;
	end
	halfsum = 0.5 * (f + X.^2);
	conjugate = lft_llt(X, halfsum, S, fusionopt);
	M = S.^2 - 2 * conjugate;
endfunction

//Main function: Moreau envelope of bivariate function (not necessarily convex)
function [M,g,Conjpartial,Conj] = me_llt2d(Xr, Xc, f, Sr, Sc)
	// f(i,j) is the value at (Xr(i), Xc(j)); the envelope is evaluated on Sr x Sc as
	//   M(p,q) = Sr(p)^2 + Sc(q)^2 - 2*Conj(p,q)
	// where Conj is the conjugate of g = (Xr^2 + Xc^2 + f)/2, computed one variable at a
	// time with lft_llt. lft_llt is assumed to return the discrete conjugate
	// max_k s*x(k) - h(k), consistently with me_llt.
	// Outputs: M, g, the negated partial conjugate Conjpartial and the conjugate Conj.
	[b,s]=checksize(length(Xr),length(Xc),size(f),'me_llt2d');
	if ~b, error(s);end;
	// step 1.- g(i,j) = (Xr(i)^2 + Xc(j)^2 + f(i,j)) / 2, built on the full grid
	[X,Y]=ndgrid(Xr.^2,Xc.^2);g=(X+Y+f)/2;clear X Y;
	
	// step 2.- negated partial conjugates with respect to the second variable:
	//   Conjpartial(i,q) = -max_j [Sc(q)*Xc(j) - g(i,j)]
	// The sign makes step 3 a plain conjugate again, now in the first variable.
	Conjpartial = ones(size(Xr,1),size(Sc,1));
	if (length(Xc) > 1)
		for i = 1:size(Xr,1)
			Conjpartial(i,:) = -1 .* lft_llt(Xc, g(i,:)', Sc)';
		end
	else
		// single sample Xc(1): the maximum is attained there, so
		//   Conjpartial(:,q) = g(:,1) - Sc(q)*Xc(1)
		// for each of the size(Sc,1) columns. Previously the loop ran over size(Sr,1)
		// and dropped the Sc(q)*Xc(1) term.
		for q = 1:size(Sc,1)
			Conjpartial(:,q) = g(:, 1) - Sc(q) * Xc(1);
		end
	end

	// step 3.- conjugate with respect to the first variable, column by column
	Conj = ones(size(Sr,1),size(Sc,1));
	for j = 1:size(Sc,1)
		Conj(:,j) = lft_llt(Xr, Conjpartial(:,j), Sr);
	end

	// step 4.- M(i,j) = Sr(i)^2 + Sc(j)^2 - 2*Conj(i,j), built on the full grid
	[Sx,Sy]=ndgrid(Sr.^2,Sc.^2);M=Sx+Sy - 2 * Conj;clear Sx;clear Sy;
endfunction
///////////////////// DISCRETE LEGENDRE TRANSFORM using LLT ///////////END


///////////////////// Parabolic Envelope ///////////BEGIN

//internal function: Parabolic envelope of a function
//represented as a matrix. Useful for computing distance transforms
function out=me_pe_d(f)
	// Lower envelope of the parabolas centred at the sample indices:
	//   out(q) = min_p (q - p)^2 + f(p),   p, q = 1..n,  n = length(f)
	// Algorithm of Felzenszwalb & Huttenlocher (distance transform of sampled functions,
	// p. 7); comments marked //* follow the paper's pseudocode. Indices run over 1..n
	// here (0..n-1 in the paper). Samples f(q) = %inf with q >= 2 are skipped.
	// Returns an n x 1 column (empty for empty f).
	n = length(f);
	// Preallocate the work arrays. This avoids element-by-element growth and keeps them
	// distinct from same-named variables visible from a calling function (Scilab
	// functions can read their callers' variables).
	v = zeros(n, 1);     //* Locations of parabolas in lower envelope
	z = zeros(n + 1, 1); //* Locations of boundaries between parabolas
	out = zeros(n, 1);
	k = 1;               //* Index of rightmost parabola in lower envelope
	v(1) = 1;
	z(1) = -%inf;
	z(2) = %inf;
	for q = 2:n //* Compute lower envelope
		if ~isinf(f(q)) // if f(q) is inf, no intersection of the parabola will be suitable
			// s: abscissa where the parabola at q meets the rightmost envelope parabola v(k)
			s = ((f(q) + q^2) - (f(v(k)) + v(k)^2)) / (2*q - 2*v(k));
			// Pop envelope parabolas hidden by the new one: their intersection with it
			// lies left of their own left boundary. s = -%inf (f(v(k)) infinite) stops
			// the popping, so the boundary z(1) = -%inf is never crossed and k >= 1.
			while (s <= z(k)) & (~isinf(s) | s > 0)
				k = k - 1;
				s = ((f(q) + q^2) - (f(v(k)) + v(k)^2)) / (2*q - 2*v(k));
			end
			k = k + 1;
			v(k) = q;
			z(k) = s;
			z(k + 1) = %inf;
		end
	end
	k = 1;
	for q = 1:n //* Fill in values of distance transform
		while z(k + 1) < q
			k = k + 1;
		end
		out(q) = (q - v(k))^2 + f(v(k));
	end
endfunction

//main function: parabolic envelope of a univariate function on
//an interval
function [M] = me_pe(x0, xn, f)
	// Moreau envelope of f sampled on the regular grid x0, x0+h, ..., xn with
	// n = length(f) points and h = (xn-x0)/(n-1), evaluated on that same grid (S = X):
	//   M(j) = min_i (x_j - x_i)^2 + f(i)
	n = length(f);
	if n <= 1 then
		// a single sample is its own envelope; the step (xn-x0)/0 would be undefined
		M = f;
	else
		h = (xn-x0)/(n-1);
		hsqr = h^2;
		// work in integer grid units and scale back:
		//   M(x_j) = h^2 * min_i [ (j-i)^2 + f(x_0+i*h)/h^2 ]
		M = hsqr * me_pe_d(f./hsqr);
	end
endfunction

/// Map a function to each of the rows of a matrix, putting the resulting row vectors into the output matrix
/// @param func A function that takes two parameters (the first is a row vector,
///        the second the number of the current row) and returns a row vector;
///        the result is stored in out(i,:), so it must have size(m,2) entries
/// @param m    Input matrix
/// @return out Matrix of the same size as m whose row i is func(m(i,:), i)
function out=rowMap(func,m)
	// statements end with ';' so that no intermediate value is echoed to the console
	out = m;
	for i = 1:size(m, 1)
		out(i, :) = func(m(i, :), i);
	end
endfunction

/// Map a function to each of the columns of a matrix, putting the resulting column vectors
/// into the output matrix
/// @param func A function called as func(col, j) with col = m(:,j) and j the column number;
///        the result is stored in out(:,j), so it must have size(m,1) entries
/// @param m    Input matrix
/// @return out Matrix of the same size as m whose column j is func(m(:,j), j)
function out=colMap(func,m)
	// statements end with ';' so that no intermediate value is echoed to the console
	out = m;
	for j = 1:size(m, 2)
		out(:, j) = func(m(:, j), j);
	end
endfunction

//internal function: apply pe on a (2d matrix)
function out=me_pe2d_d(hx, hy, m)
	// Parabolic envelope of the matrix m sampled on a regular grid with step hx along the
	// first index and hy along the second:
	//   out(i,j) = min_{k,l} hx^2*(i-k)^2 + hy^2*(j-l)^2 + m(k,l)
	// computed by separability: one me_pe_d pass along the second index, then one along
	// the first index, each in integer grid units.
	function ret=do_pe(colvect,colnum)
		// colMap callback; the column number colnum is not needed
		ret=me_pe_d(colvect);
	endfunction
	hxsqr=hx^2;hysqr=hy^2;
	// pass 1: the columns of m' are the rows of m; values are scaled to units of hy^2
	alongy = colMap(do_pe, m'/hysqr)';
	// pass 2: rescale to units of hx^2, run along the first index, and scale back
	out = hxsqr * colMap(do_pe, (hysqr/hxsqr) * alongy);
endfunction

//main function: parabolic envelope of a bivariate function on
//a grid
function [M] = me_pe2d(x0, xn, y0, yn, f)
	// Moreau envelope of f sampled on the regular grid [x0,xn] x [y0,yn]
	// (size(f,1) points along x, size(f,2) along y), evaluated on the same grid:
	//   M(i,j) = min_{k,l} (x_i-x_k)^2 + (y_j-y_l)^2 + f(k,l)
	nx = size(f,1); ny = size(f,2);
	// Grid steps. An axis with a single sample has no step ((xn-x0)/0). Along such an
	// axis the 1-D envelope of one sample is the sample itself, so the step does not
	// matter there and 1 is used.
	hx = 1;
	if nx > 1 then
		hx = (xn-x0)/(nx-1);
	end
	hy = 1;
	if ny > 1 then
		hy = (yn-y0)/(ny-1);
	end
	M = me_pe2d_d(hx, hy, f);
endfunction


///////////////////// Parabolic Envelope ///////////END
