#include "dense.h"
#include "Mutils.h"

/** 
 * Perform a left cyclic shift of columns j to k in the upper triangular
 * matrix x, then restore it to upper triangular form with Givens rotations.
 * The algorithm is based on the Fortran routine DCHEX from Linpack.
 *
 * The lower triangle of x is not modified.
 * 
 * @param x Matrix stored in column-major order
 * @param ldx leading dimension of x
 * @param j column number (0-based) that will be shifted to position k
 * @param k last column number (0-based) to be shifted
 * @param cosines cosines of the Givens rotations
 * @param sines sines of the Givens rotations
 * 
 * @return 0 for success
 */
static
int left_cyclic(double x[], int ldx, int j, int k,
		double cosines[], double sines[])
{
    double *lastcol;
    int i, jj;

    if (j >= k)
	error("incorrect left cyclic shift, j (%d) >= k (%d)", j, k);
    if (j < 0)
	error("incorrect left cyclic shift, j (%d) < 0", j, k);
    if (ldx < k)
	error("incorrect left cyclic shift, k (%d) > ldx (%d)", k, ldx);
    lastcol = (double*) R_alloc(k+1, sizeof(double));
				/* keep a copy of column j */
    for(i = 0; i <= j; i++) lastcol[i] = x[i + j*ldx];
				/* For safety, zero the rest */
    for(i = j+1; i <= k; i++) lastcol[i] = 0.;
    for(jj = j+1; jj <= k; jj++) { /* columns to be shifted */
	int diagind = jj*(ldx+1), ind = (jj-j) - 1;
	double tmp = x[diagind], cc, ss;
				/* Calculate the Givens rotation. */
				/* This modified the super-diagonal element */
	F77_CALL(drotg)(x + diagind-1, &tmp, cosines + ind, sines + ind);
	cc = cosines[ind]; ss = sines[ind];
				/* Copy column jj+1 to column jj. */
	for(i = 0; i < jj; i++) x[i + (jj-1)*ldx] = x[i+jj*ldx];
				/* Apply rotation to columns up to k */
	for(i = jj; i < k; i++) {
	    tmp = cc*x[(jj-1)+i*ldx] + ss*x[jj+i*ldx];
	    x[jj+i*ldx] = cc*x[jj+i*ldx] - ss*x[(jj-1)+i*ldx];
	    x[(jj-1)+i*ldx] = tmp;
	}
				/* Apply rotation to lastcol */
	lastcol[jj] = -ss*lastcol[jj-1]; lastcol[jj-1] *= cc;
    }
				/* Copy lastcol to column k */
    for(i = 0; i <= k; i++) x[i+k*ldx] = lastcol[i];
    return 0;
}

static
SEXP getGivens(double x[], int ldx, int jmin, int rank)
{
    int shiftlen = (rank - jmin) - 1;
    SEXP ans = PROTECT(allocVector(VECSXP, 4)), nms, cosines, sines;

    SET_VECTOR_ELT(ans, 0, ScalarInteger(jmin));
    SET_VECTOR_ELT(ans, 1, ScalarInteger(rank));
    SET_VECTOR_ELT(ans, 2, cosines = allocVector(REALSXP, shiftlen));
    SET_VECTOR_ELT(ans, 3, sines = allocVector(REALSXP, shiftlen));
    setAttrib(ans, R_NamesSymbol, nms = allocVector(STRSXP, 4));
    SET_STRING_ELT(nms, 0, mkChar("jmin"));
    SET_STRING_ELT(nms, 1, mkChar("rank"));
    SET_STRING_ELT(nms, 2, mkChar("cosines"));
    SET_STRING_ELT(nms, 3, mkChar("sines"));
    if (left_cyclic(x, ldx, jmin, rank - 1, REAL(cosines), REAL(sines)))
	error("Unknown error in getGivens");
    UNPROTECT(1);
    return ans;
}

SEXP checkGivens(SEXP X, SEXP jmin, SEXP rank)
{
    SEXP ans = PROTECT(allocVector(VECSXP, 2)),
	Xcp = PROTECT(duplicate(X));
    int  *Xdims;

    if (!(isReal(X) & isMatrix(X)))
	error("X must be a numeric (double precision) matrix");
    Xdims = INTEGER(coerceVector(getAttrib(X, R_DimSymbol), INTSXP));
    SET_VECTOR_ELT(ans, 1, getGivens(REAL(Xcp), Xdims[0],
				     asInteger(jmin), asInteger(rank)));
    SET_VECTOR_ELT(ans, 0, Xcp);
    UNPROTECT(2);
    return ans;
}

SEXP lsq_dense_Chol(SEXP X, SEXP y)
{
    SEXP ans;
    int info, n, p, k, *Xdims, *ydims;
    double *xpx, d_one = 1., d_zero = 0.;

    if (!(isReal(X) & isMatrix(X)))
	error("X must be a numeric (double precision) matrix");
    Xdims = INTEGER(coerceVector(getAttrib(X, R_DimSymbol), INTSXP));
    n = Xdims[0];
    p = Xdims[1];
    if (!(isReal(y) & isMatrix(y)))
	error("y must be a numeric (double precision) matrix");
    ydims = INTEGER(coerceVector(getAttrib(y, R_DimSymbol), INTSXP));
    if (ydims[0] != n)
	error(
	    "number of rows in y (%d) does not match number of rows in X (%d)",
	    ydims[0], n);
    k = ydims[1];
    if (k < 1 || p < 1) return allocMatrix(REALSXP, p, k);
    ans = PROTECT(allocMatrix(REALSXP, p, k));
    F77_CALL(dgemm)("T", "N", &p, &k, &n, &d_one, REAL(X), &n, REAL(y), &n,
		    &d_zero, REAL(ans), &p);
    xpx = (double *) R_alloc(p * p, sizeof(double));
    F77_CALL(dsyrk)("U", "T", &p, &n, &d_one, REAL(X), &n, &d_zero,
		    xpx, &p);
    F77_CALL(dposv)("U", &p, &k, xpx, &p, REAL(ans), &p, &info);
    if (info) error("Lapack routine dposv returned error code %d", info);
    UNPROTECT(1);
    return ans;
}
 
    
SEXP lsq_dense_QR(SEXP X, SEXP y)
{
    SEXP ans;
    int info, n, p, k, *Xdims, *ydims, lwork;
    double *work, tmp, *xvals;

    if (!(isReal(X) & isMatrix(X)))
	error("X must be a numeric (double precision) matrix");
    Xdims = INTEGER(coerceVector(getAttrib(X, R_DimSymbol), INTSXP));
    n = Xdims[0];
    p = Xdims[1];
    if (!(isReal(y) & isMatrix(y)))
	error("y must be a numeric (double precision) matrix");
    ydims = INTEGER(coerceVector(getAttrib(y, R_DimSymbol), INTSXP));
    if (ydims[0] != n)
	error(
	    "number of rows in y (%d) does not match number of rows in X (%d)",
	    ydims[0], n);
    k = ydims[1];
    if (k < 1 || p < 1) return allocMatrix(REALSXP, p, k);
    xvals = (double *) R_alloc(n * p, sizeof(double));
    Memcpy(xvals, REAL(X), n * p);
    ans = PROTECT(duplicate(y));
    lwork = -1;
    F77_CALL(dgels)("N", &n, &p, &k, xvals, &n, REAL(ans), &n,
		    &tmp, &lwork, &info);
    if (info)
	error("First call to Lapack routine dgels returned error code %d",
	      info);
    lwork = (int) tmp;
    work = (double *) R_alloc(lwork, sizeof(double));
    F77_CALL(dgels)("N", &n, &p, &k, xvals, &n, REAL(ans), &n,
		    work, &lwork, &info);
    if (info)
	error("Second call to Lapack routine dgels returned error code %d",
	      info);
    UNPROTECT(1);
    return ans;
}
 
SEXP lapack_qr(SEXP Xin, SEXP tl)
{
    SEXP ans, Givens, Gcpy, nms, pivot, qraux, X;
    int i, n, nGivens = 0, p, trsz, *Xdims, rank;
    double rcond = 0., tol = asReal(tl), *work;

    if (!(isReal(Xin) & isMatrix(Xin)))
	error("X must be a real (numeric) matrix");
    if (tol < 0.) error("tol, given as %g, must be non-negative", tol);
    if (tol > 1.) error("tol, given as %g, must be <= 1", tol);
    ans = PROTECT(allocVector(VECSXP,5));
    SET_VECTOR_ELT(ans, 0, X = duplicate(Xin));
    Xdims = INTEGER(coerceVector(getAttrib(X, R_DimSymbol), INTSXP));
    n = Xdims[0]; p = Xdims[1];
    SET_VECTOR_ELT(ans, 2, qraux = allocVector(REALSXP, (n < p) ? n : p));
    SET_VECTOR_ELT(ans, 3, pivot = allocVector(INTSXP, p));
    for (i = 0; i < p; i++) INTEGER(pivot)[i] = i + 1;
    trsz = (n < p) ? n : p;	/* size of triangular part of decomposition */
    rank = trsz;
    Givens = PROTECT(allocVector(VECSXP, rank - 1));
    setAttrib(ans, R_NamesSymbol, nms = allocVector(STRSXP, 5));
    SET_STRING_ELT(nms, 0, mkChar("qr"));
    SET_STRING_ELT(nms, 1, mkChar("rank"));
    SET_STRING_ELT(nms, 2, mkChar("qraux"));
    SET_STRING_ELT(nms, 3, mkChar("pivot"));
    SET_STRING_ELT(nms, 4, mkChar("Givens"));
    if (n > 0 && p > 0) {
	int  info, *iwork, lwork;
	double *xpt = REAL(X), tmp;

	lwork = -1;
	F77_CALL(dgeqrf)(&n, &p, xpt, &n, REAL(qraux), &tmp, &lwork, &info);
	if (info)
	    error("First call to dgeqrf returned error code %d", info);
	lwork = (int) tmp;
	work = (double *) R_alloc((lwork < 3*trsz) ? 3*trsz : lwork,
				  sizeof(double));
	F77_CALL(dgeqrf)(&n, &p, xpt, &n, REAL(qraux), work, &lwork, &info);
	if (info)
	    error("Second call to dgeqrf returned error code %d", info);
	iwork = (int *) R_alloc(trsz, sizeof(int));
	F77_CALL(dtrcon)("1", "U", "N", &rank, xpt, &n, &rcond,
			 work, iwork, &info);
	if (info)
	    error("Lapack routine dtrcon returned error code %d", info);
	while (rcond < tol) {	/* check diagonal elements */
	    double minabs = (xpt[0] < 0.) ? -xpt[0]: xpt[0];
	    int jmin = 0;
	    for (i = 1; i < rank; i++) {
		double el = xpt[i*(n+1)];
		el = (el < 0.) ? -el: el;
		if (el < minabs) {
		    jmin = i;
		    minabs = el;
		}
	    }
	    if (jmin < (rank - 1)) {
		SET_VECTOR_ELT(Givens, nGivens, getGivens(xpt, n, jmin, rank));
		nGivens++;
	    }		
	    rank--;
	    F77_CALL(dtrcon)("1", "U", "N", &rank, xpt, &n, &rcond,
			     work, iwork, &info);
	    if (info)
		error("Lapack routine dtrcon returned error code %d", info);
	}
    }
    SET_VECTOR_ELT(ans, 4, Gcpy = allocVector(VECSXP, nGivens));
    for (i = 0; i < nGivens; i++)
	SET_VECTOR_ELT(Gcpy, i, VECTOR_ELT(Givens, i));
    SET_VECTOR_ELT(ans, 1, ScalarInteger(rank));
    setAttrib(ans, install("useLAPACK"), ScalarLogical(1));
    setAttrib(ans, Matrix_rcondSym, ScalarReal(rcond));
    UNPROTECT(2);
    return ans;
}
