
// Header
// Bayesian matting job: estimates an alpha matte for an image from a trimap.
// The class has no data members in this declaration; all working state is
// local to Run().
class BayesianMatting
{
public:
	// Iteratively solves the 6x6 linear system a * x = b.
	//   a : coefficient matrix; its diagonal entries are used as divisors.
	//   b : right-hand side.
	//   x : in  - initial guess, out - current estimate of the solution.
	// Returns true once the iteration meets its convergence tolerance,
	// false if the iteration limit is reached first.
	Binary SolveLinearEquations(Real32 a[6][6], Real32 b[6], Real32 x[6]);
	// Loads the original image and the trimap, estimates alpha for the
	// undecided trimap pixels and saves the resulting alpha image.
	// Takes no arguments; file names and sizes are fixed inside the function.
	void   Run                 ();
};


// Gauss-Seidel
// Iteratively solves the 6x6 linear system a * x = b, updating x in place.
//
// One sweep visits the six rows in order. For row i the residual
// b[i] - sum_j a[i][j] * x[j] is evaluated with the most recent values of x,
// divided by the diagonal a[i][i], and added to x[i] as a correction.
//
//   a : coefficient matrix; every diagonal entry must be non-zero.
//   b : right-hand side.
//   x : in  - initial guess, out - current estimate of the solution.
//
// Returns true when the summed absolute corrections of one complete sweep
// are below the tolerance. Returns false when a diagonal entry is zero
// (x is then left untouched) or when the sweep limit is reached first.
Binary BayesianMatting::SolveLinearEquations(Real32 a[6][6], Real32 b[6], Real32 x[6])
{
	const SInt32 dimension  = 6;
	const SInt32 max_sweeps = 10000;
	const Real32 tolerance  = 0.00001f;

	// A zero pivot would make the division below produce inf/NaN and the
	// iteration could never converge, so reject such systems up front.
	for(SInt32 i = 0; i < dimension; ++i)
	{
		if(a[i][i] == 0.0f)
		{
			return false;
		}
	}

	for(SInt32 sweep = 0; sweep < max_sweeps; ++sweep)
	{
		Real32 error = 0.0f;

		for(SInt32 i = 0; i < dimension; ++i)
		{
			// Residual of row i with the newest x, scaled by the pivot.
			Real32 s = b[i];
			for(SInt32 j = 0; j < dimension; ++j)
			{
				s -= a[i][j] * x[j];
			}
			s /= a[i][i];

			x[i] += s;

			error += Fabs32(s);
		}

		// Convergence is judged on the whole sweep. Testing inside the row
		// loop would report success as soon as the first row's correction
		// is small, even if the remaining unknowns are still far off.
		if(error < tolerance)
		{
			return true;
		}
	}
	return false;
}


// Bayesian Matting
// Runs the complete matting job.
//
// Steps:
//   1. Load "original.pfm" and "trimap.pfm" from c:\redqueen.
//   2. Classify every pixel by the trimap's R channel: > 0.95 is a
//      foreground sample, < 0.05 is a background sample, everything in
//      between is flagged as undecided.
//   3. Compute mean colour and 3x3 colour covariance of the foreground and
//      of the background samples, and invert both covariances.
//   4. For 500 passes, for every undecided pixel: build a 6x6 linear system
//      from the current alpha (the trimap's R value), solve it for the
//      foreground colour F and background colour B, re-estimate alpha from
//      the observed colour C, F and B, clamp it to [0, 1] and write it back
//      into the trimap.
//   5. Save the trimap, which now holds the alpha estimate, as "alpha.pfm".
//
// Returns early, after printing a message, when the trimap provides no
// foreground or no background samples, or when a covariance matrix cannot
// be inverted.
void BayesianMatting::Run()
{

	// Only c * c is used; it divides every alpha-dependent term of the
	// linear system. Presumably the standard deviation of the colour noise
	// in the matting model - TODO confirm.
	const Real32 c  = 0.07f;
	const Real32 cc = c * c;

	// NOTE(review): the processed region is hard-coded; this assumes both
	// images are at least 512x512 - confirm against Image::GetColor.
	const SInt32 w  = 512;
	const SInt32 h  = 512;

	const SInt32 start_u = 0;
	const SInt32 start_v = 0;
	const SInt32 end_u   = w;
	const SInt32 end_v   = h;

	Image Trimap, Original;

	// NOTE(review): the outcome of Load() is not checked here; whether it
	// reports failure at all is not visible from this file.
	Original.Name = "c:\\redqueen\\original.pfm"; Original.Load();
	Trimap  .Name = "c:\\redqueen\\trimap.pfm" ;  Trimap  .Load();

	CieRgb sub_f, sub_b;
	CieRgb sum_f, sum_b;
	CieRgb avr_f, avr_b;
	SInt32 cnt_f, cnt_b;
	CieRgb col_o, col_t;

	// flags[v*w + u] is true for undecided pixels (alpha still to be solved).
	std::vector<Binary> flags(w*h);

	Matrix3x3 cov_f, icv_f;
	Matrix3x3 cov_b, icv_b;

	// Compute Average
	cnt_f = cnt_b = 0;
	sum_f = sum_b = Black;
	for(SInt32 v = start_v; v < end_v; ++v)
	{
		for(SInt32 u = start_u; u < end_u; ++u)
		{

			Original.GetColor(col_o, u, v);
			Trimap  .GetColor(col_t, u, v);

			if(col_t.R > 0.95f) { sum_f += col_o; ++cnt_f; }
			if(col_t.R < 0.05f) { sum_b += col_o; ++cnt_b; }
			flags[v*w + u] = (0.05f <= col_t.R) && (col_t.R <= 0.95f);
		}
	}

	// Without samples on both sides the means and covariances below would be
	// divisions by zero, so stop here instead of propagating inf/NaN.
	if(cnt_f == 0 || cnt_b == 0)
	{
		std::cout << "Trimap Has No Foreground Or Background Samples\n";
		return;
	}

	avr_f = sum_f / (Real32)cnt_f;
	avr_b = sum_b / (Real32)cnt_b;

	// Compute Covariance
	cov_f.Zero();
	cov_b.Zero();
	cnt_f = cnt_b = 0;
	for(SInt32 v = start_v; v < end_v; ++v)
	{
		for(SInt32 u = start_u; u < end_u; ++u)
		{

			// Undecided pixels contribute to neither distribution.
			if(flags[v*w + u])
			{
				continue;
			}

			Original.GetColor(col_o, u, v);
			Trimap  .GetColor(col_t, u, v);

			if(col_t.R > 0.95f)
			{
				sub_f = (col_o - avr_f);
				cov_f[0][0] += sub_f.R*sub_f.R; cov_f[0][1] += sub_f.R*sub_f.G; cov_f[0][2] += sub_f.R*sub_f.B;
				cov_f[1][0] += sub_f.G*sub_f.R; cov_f[1][1] += sub_f.G*sub_f.G; cov_f[1][2] += sub_f.G*sub_f.B;
				cov_f[2][0] += sub_f.B*sub_f.R; cov_f[2][1] += sub_f.B*sub_f.G; cov_f[2][2] += sub_f.B*sub_f.B;
				++cnt_f;
			}
			if(col_t.R < 0.05f)
			{
				sub_b = (col_o - avr_b);
				cov_b[0][0] += sub_b.R*sub_b.R; cov_b[0][1] += sub_b.R*sub_b.G; cov_b[0][2] += sub_b.R*sub_b.B;
				cov_b[1][0] += sub_b.G*sub_b.R; cov_b[1][1] += sub_b.G*sub_b.G; cov_b[1][2] += sub_b.G*sub_b.B;
				cov_b[2][0] += sub_b.B*sub_b.R; cov_b[2][1] += sub_b.B*sub_b.G; cov_b[2][2] += sub_b.B*sub_b.B;
				++cnt_b;
			}
		}
	}
	// The counts repeat the classification of the first pass, so they are
	// non-zero here whenever the guard above was passed.
	cov_f /= (Real32)cnt_f;
	cov_b /= (Real32)cnt_b;

	if(!cov_f.Inverse(icv_f)) { std::cout << "Cannot Compute Inverse\n"; return; }
	if(!cov_b.Inverse(icv_b)) { std::cout << "Cannot Compute Inverse\n"; return; }

	// The products (inverse covariance * mean colour) do not depend on the
	// pixel, so they are computed once instead of once per pixel and pass.
	Real32 prior_f[3];
	Real32 prior_b[3];
	for(SInt32 k = 0; k < 3; ++k)
	{
		prior_f[k] = icv_f[k][0]*avr_f.R + icv_f[k][1]*avr_f.G + icv_f[k][2]*avr_f.B;
		prior_b[k] = icv_b[k][0]*avr_b.R + icv_b[k][1]*avr_b.G + icv_b[k][2]*avr_b.B;
	}

	// Iteration
	Real32 matrix[6][6];
	Real32 vector[6];
	Real32 result[6];

	for(SInt32 i = 0; i < 500; ++i)
	{

		std::cout << "iteration" << i << "\r";

		// Update F & G
		for(SInt32 v = start_v; v < end_v; ++v)
		{
			for(SInt32 u = start_u; u < end_u; ++u)
			{

				// Only undecided pixels are solved.
				if(!flags[v*w + u])
				{
					continue;
				}

				Trimap  .GetColor(col_t, u, v);
				Original.GetColor(col_o, u, v);

				// Current alpha estimate and the terms derived from it.
				Real32 a  = col_t.R;
				Real32 m  = 1.0f-a;
				Real32 aa = a * a / cc;
				Real32 am = a * m / cc;
				Real32 mm = m * m / cc;
				Real32 acc = a / cc;
				Real32 mcc = m / cc;

				Real32 r = col_o.R;
				Real32 g = col_o.G;
				Real32 b = col_o.B;

				// Matrix
				// Diagonal 3x3 blocks: inverse covariance plus an
				// alpha-dependent multiple of the identity.
				matrix[0][0] = icv_f[0][0] + aa; matrix[0][1] = icv_f[0][1];      matrix[0][2] = icv_f[0][2];
				matrix[1][0] = icv_f[1][0];      matrix[1][1] = icv_f[1][1] + aa; matrix[1][2] = icv_f[1][2];
				matrix[2][0] = icv_f[2][0];      matrix[2][1] = icv_f[2][1];      matrix[2][2] = icv_f[2][2] + aa;
				matrix[3][3] = icv_b[0][0] + mm; matrix[3][4] = icv_b[0][1];      matrix[3][5] = icv_b[0][2];
				matrix[4][3] = icv_b[1][0];      matrix[4][4] = icv_b[1][1] + mm; matrix[4][5] = icv_b[1][2];
				matrix[5][3] = icv_b[2][0];      matrix[5][4] = icv_b[2][1];      matrix[5][5] = icv_b[2][2] + mm;

				// Off-diagonal 3x3 blocks: am times the identity.
				matrix[0][3] = am; matrix[0][4] = 0;  matrix[0][5] = 0;
				matrix[1][3] = 0;  matrix[1][4] = am; matrix[1][5] = 0;
				matrix[2][3] = 0;  matrix[2][4] = 0;  matrix[2][5] = am;
				matrix[3][0] = am; matrix[3][1] = 0;  matrix[3][2] = 0;
				matrix[4][0] = 0;  matrix[4][1] = am; matrix[4][2] = 0;
				matrix[5][0] = 0;  matrix[5][1] = 0;  matrix[5][2] = am;

				// Vector
				vector[0] = prior_f[0] + r*acc;
				vector[1] = prior_f[1] + g*acc;
				vector[2] = prior_f[2] + b*acc;

				vector[3] = prior_b[0] + r*mcc;
				vector[4] = prior_b[1] + g*mcc;
				vector[5] = prior_b[2] + b*mcc;

				// Initialize
				// The solver starts from the mean foreground/background colours.
				result[0] = avr_f.R;
				result[1] = avr_f.G;
				result[2] = avr_f.B;
				result[3] = avr_b.R;
				result[4] = avr_b.G;
				result[5] = avr_b.B;

				// A failed solve is reported but its last estimate is still used.
				if(!SolveLinearEquations(matrix, vector, result))
				{
					std::cout << "Can't Solve Linear Eq." << std::endl;
				}

				CieRgb C(        r,         g,         b);
				CieRgb F(result[0], result[1], result[2]);
				CieRgb B(result[3], result[4], result[5]);

				// Alpha is the projection of (C - B) onto (F - B). When F and B
				// coincide the denominator is zero, and a NaN would slip through
				// ordinary clamping because every comparison with NaN is false.
				// In those cases the previous alpha is kept.
				CieRgb FB = F - B;
				Real32 ff = FB.Dot(FB);
				if(ff > 0.0f)
				{
					Real32 estimate = (C-B).Dot(FB) / ff;
					if(estimate >= 1.0f)
					{
						a = 1.0f;
					}
					else if(estimate <= 0.0f)
					{
						a = 0.0f;
					}
					else if(estimate > 0.0f) // false only for NaN at this point
					{
						a = estimate;
					}
				}

				// Update Alpha
				Trimap.SetColor(CieRgb(a, a, a), u, v);

			}
		}

	}

	// Save Alpha Image
	Trimap.Name = "c:\\redqueen\\alpha.pfm";
	Trimap.Save(32/*bit depth*/);

}