#include <rumba/manifoldmatrix.h>

#ifdef TEST_MATRIX
#include <iostream>
#endif


using RUMBA::Manifold;
using RUMBA::ManifoldMatrix;

ManifoldMatrix
RUMBA::makeProductManifold(const Manifold<double>& left, const Manifold<double>& right)
{
	intPoint result_dim;

	int ldim = 0;
	int rdim = 0;
	bool result_transposed = false;

	if ( left.width() > 1 ) ldim++;
	if ( left.height() > 1 ) ldim++;
	if ( left.depth() > 1 ) ldim++;

	if ( right.width() > 1 ) rdim++;
	if ( right.height() > 1 ) rdim++;
	if ( right.depth() > 1 ) rdim++;



	if (ldim >1)
	{
		result_dim = left.extent();
		result_dim.t = right.extent().t;
	}
	else if ( rdim >1 )
	{
		result_dim = right.extent();
		result_dim.t = right.extent().t;
		result_transposed = true;
	}
	else
	{
		result_dim = left.extent();
		result_dim.t = right.extent().t;
	}

	Manifold<double> result (result_dim);

	if ( !result_transposed )
		return ManifoldMatrix (result,  result.begin(), result.extent().x * result.extent().y * result.extent().z, result.extent().t, false );
	else
		return ManifoldMatrix (result,  result.begin(), result.extent().t, result.extent().x * result.extent().y * result.extent().z, true );

}

/*
 * Binary form: forwards the underlying manifolds (the M members) of both
 * operands to makeProductManifold, which allocates and shapes the returned
 * ManifoldMatrix. The contents of `left` and `right` are not read.
 */
ManifoldMatrix RUMBA::manifold_generator::operator() (const ManifoldMatrix& left, const ManifoldMatrix& right)
{
	const Manifold<double>& lhs = left.M;
	const Manifold<double>& rhs = right.M;
	return makeProductManifold(lhs, rhs);
}

/*
 * Unary form: allocates a new Manifold<double> with the same extent as
 * my_matrix's underlying manifold, and wraps it with my_matrix's row count,
 * column count and transpose flag. Element values of my_matrix are not
 * copied; only its shape is reused.
 */
ManifoldMatrix RUMBA::manifold_generator::operator() (const ManifoldMatrix& my_matrix)
{
	const intPoint shape = my_matrix.M.extent();
	Manifold<double> storage (shape);
	return ManifoldMatrix(storage, storage.begin(),
	                      my_matrix.rows(), my_matrix.cols(),
	                      my_matrix.Transpose);
}

#ifdef TEST_MATRIX
// Manual smoke test, built only when TEST_MATRIX is defined. It builds a
// 2x2 and a 2x3 ManifoldMatrix, multiplies them (via multiply() and via
// operator*), inverts the 2x2 one, and prints the results for inspection.
// Output is checked by eye; nothing is asserted.
int main()
{
	// 2x2 matrix backed by a manifold of extent (2,1,1,2).
	Manifold<double> M(intPoint(2,1,1,2));
	M[0] = 9;
	M[1] = -2;
	M[2] = -1;
	M[3] = 3;

	ManifoldMatrix mM (M,M.begin(),2,2,false);
	// std::endl (flush) on purpose: progress markers must reach the terminal
	// even if a later step crashes.
	std::cout << "Got here" << std::endl;

	// 2x3 matrix backed by a manifold of extent (2,1,1,3).
	Manifold<double> N(intPoint(2,1,1,3));
	N[0] = 5;
	N[1] = 0;
	N[2] = 5;
	N[3] = 1;
	N[4] = 2;
	N[5] = 3;

	ManifoldMatrix mN (N,N.begin(),2,3,false);
	std::cout << "Got here" << std::endl;

	// call multiply
	ManifoldMatrix mP = multiply(mM, mN);

	// or just use operator*
	mP = mM * mN;

	// Print the 2x3 product so the multiplication result is actually visible.
	std::cout << "Product:" << std::endl;
	std::cout << mP.element(0,0) << " " << mP.element(0,1) << " " << mP.element(0,2) << std::endl;
	std::cout << mP.element(1,0) << " " << mP.element(1,1) << " " << mP.element(1,2) << std::endl;

//	RUMBA::LU_Functor<ManifoldMatrix> x(mM);
//	cout << x.determinant() << endl;
//	std::vector<double> v;
//	v.push_back(-13); v.push_back(14);
//	x.solve(v);
//	cout << "Det: " << x.determinant() << endl;
//	cout << "Solution: " << v[0] << " " << v[1] << endl;
//	cout << mM.element(0,0) << " " <<  mM.element(0,1) << endl;
//	cout << mM.element(1,0) << " " <<  mM.element(1,1) << endl;
	ManifoldMatrix mI = invert(mM);
	std::cout << "Inverse:" << std::endl;
	std::cout << mI.element(0,0) << " " <<  mI.element(0,1) << std::endl;
	std::cout << mI.element(1,0) << " " <<  mI.element(1,1) << std::endl;


	std::cout << "Got here: end" << std::endl;
}

#endif
