/*
 * Generate an arbitrary algebraic number, specified as a root of an
 * integer polynomial lying within a given rational interval.
 */

#include <stdio.h>
#include <stdlib.h>
#include <assert.h>
#include <vector>

#include "spigot.h"
#include "error.h"
#include "funcs.h"

class Algebraic : public Source {
    /*
     * This class assumes that the root we're looking for is the
     * unique one in the interval (0,1).
     */
    std::vector<bigint> P;
    std::vector<bigint> P_orig;

  public:
    Algebraic(const std::vector<bigint> &aP) {
        P = aP;
        P_orig = aP;
    }

    virtual Algebraic *clone() { return new Algebraic(P_orig); }

    virtual bool gen_interval(bigint *low, bigint *high) {
        *low = 0;
        *high = 1;
        return false;
    }

    virtual bool gen_matrix(bigint *matrix) {
        /*
         * Scale up our polynomial P by replacing it with Q such that
         * Q(x) = P(x/2) (up to a constant). This means our root will
         * now be in the interval [0,2].
         */
        bigint factor = 1;
        for (int i = P.size(); i-- > 0 ;) {
            P[i] *= factor;
            factor *= 2;
        }

        /*
         * Now evaluate P at 1 (i.e. just sum the terms) to find out
         * whether the root lies in [0,1] or [1,2], i.e. which side of
         * 1/2 it was on before we scaled up.
         */
        bigint sum = 0;
        for (int i = 0; i < (int)P.size(); ++i)
            sum += P[i];
        int digit = (bigint_sign(sum) == bigint_sign(P[0])) ? 1 : 0;

        if (digit) {
            /*
             * If the root is in [1,2], we need to transform P again,
             * by finding Q such that Q(x) = P(x+1), which brings the
             * root back down into [0,1].
             */
            for (int j = P.size() - 1; j-- > 0 ;)
                for (int i = j; i+1 < (int)P.size(); ++i)
                    P[i] += P[i+1];
        }

        /*
         * Return an appropriate binary-digit matrix.
         */
        matrix[0] = 1;
        matrix[1] = digit;
        matrix[2] = 0;
        matrix[3] = 2;

        return false;
    }
};

/*
 * Semantically sensible interface, which takes a vector of integer
 * polynomial coefficients along with a pair of rational endpoints.
 */
Spigot *spigot_algebraic(const std::vector<bigint> &aP,
                           bigint nlo, bigint nhi, bigint d)
{
    std::vector<bigint> P = aP;
    int n = P.size();

#if 0
    printf("spigot_algebraic: nlo=");
    bigint_print(nlo);
    printf(" nhi=");
    bigint_print(nhi);
    printf(" d=");
    bigint_print(d);
    printf(" P=[");
    for (int i = 0; i < n; ++i) {
        if (i > 0) putchar(' ');
        bigint_print(P[i]);
    }
    printf("]\n");
#endif

    /*
     * Trim leading zero coefficients.
     */
    while (n > 0 && P[n-1] == 0)
        P.pop_back();

    /*
     * Check we still have a worthwhile polynomial.
     */
    if (n < 2)
        throw spigot_error("degenerate polynomial");

#if 0
    printf("trimmed: P=[");
    for (int i = 0; i < n; ++i) {
        if (i > 0) putchar(' ');
        bigint_print(P[i]);
    }
    printf("]\n");
#endif

    /*
     * Scale P so that the root is in the interval (nlo,nhi) rather
     * than (nlo/d,nhi/d).
     */
    {
        bigint factor = 1;
        for (int i = P.size(); i-- > 0 ;) {
            P[i] *= factor;
            factor *= d;
        }
    }

#if 0
    printf("scaled #1: P=[");
    for (int i = 0; i < n; ++i) {
        if (i > 0) putchar(' ');
        bigint_print(P[i]);
    }
    printf("]\n");
#endif

    /*
     * Scale again, this time to make P monic. That way, if it has any
     * rational roots, they will be integers.
     */
    d *= P[n-1];
    nlo *= P[n-1];
    nhi *= P[n-1];
    {
        bigint factor = 1;
        for (int i = n-1; i-- > 0 ;) {
            P[i] *= factor;
            factor *= P[n-1];
        }
    }
    P[n-1] = 1;

    /*
     * Swap round nlo and nhi if they're backwards. Partly in case the
     * user specified them the wrong way round, and also because the
     * scaling above might have negated both.
     */
    if (nlo > nhi) {
        bigint tmp = nhi;
        nhi = nlo;
        nlo = tmp;
    }

#if 0
    printf("scaled #2: nlo=");
    bigint_print(nlo);
    printf(" nhi=");
    bigint_print(nhi);
    printf(" d=");
    bigint_print(d);
    printf(" P=[");
    for (int i = 0; i < n; ++i) {
        if (i > 0) putchar(' ');
        bigint_print(P[i]);
    }
    printf("]\n");
#endif

    /*
     * Sanity-check that P(nlo) and P(nhi) at least have opposite
     * signs.
     */
    bigint plo = 0, phi = 0;
    for (int i = n; i-- > 0 ;) {
        plo = plo * nlo + P[i];
        phi = phi * nhi + P[i];
    }
#if 0
    printf("signs: plo %d, phi %d\n", bigint_sign(plo), bigint_sign(phi));
#endif
    if (bigint_sign(plo) * bigint_sign(phi) != -1)
        throw spigot_error("bad interval for polynomial root");

    /*
     * Binary-search between nlo and nhi to find a unit interval
     * containing the root. If we hit the root exactly, return a
     * rational.
     */
    while (nhi - nlo > 1) {
        bigint nmid = fdiv(nhi + nlo, 2);
        bigint p = 0;
        for (int i = n; i-- > 0 ;)
            p = p * nmid + P[i];
        if (p == 0) {
#if 0
            printf("rational root: ");
            bigint_print(nmid);
            printf("/");
            bigint_print_nl(d);
#endif
            return spigot_rational(nmid, d);
        }
        if (bigint_sign(p) == bigint_sign(phi))
            nhi = nmid;
        else
            nlo = nmid;
    }

    /*
     * Now we're searching for a root in the range (nlo,nlo+1).
     * Transform P again so that the root is translated to (0,1).
     */
#if 0
    printf("integer part: nlo=");
    bigint_print_nl(nlo);
#endif

    for (int j = n-1; j-- > 0 ;)
        for (int i = j; i+1 < n; ++i)
            P[i] += nlo * P[i+1];

#if 0
    printf("translated: P=[");
    for (int i = 0; i < n; ++i) {
        if (i > 0) putchar(' ');
        bigint_print(P[i]);
    }
    printf("]\n");
#endif

    Spigot *ret = new Algebraic(P);
    if (nlo != 0 || d != 1)
        ret = spigot_mobius(ret, 1, nlo, 0, d);
    return ret;
}

/*
 * A rational number as a (numerator, denominator) pair, and a
 * polynomial over Q as a vector of such values, element i holding the
 * coefficient of x^i. Nothing here forces denominators to be positive:
 * divide_poly builds quotients whose denominator takes the sign of the
 * divisor's leading coefficient.
 */
typedef std::pair<bigint, bigint> qval;
typedef std::vector<qval> qpoly;

#if 0 // only used for ad-hoc printf-debugging
/*
 * Debug helpers, compiled out. dprint_qval prints one rational as
 * numerator/denominator; dprint_qpoly prints a polynomial's length
 * followed by each coefficient. The commented-out dprint_qpoly calls
 * in reduce_poly are the intended users.
 */
static void dprint_qval(const qval &a)
{
    dprint("%b/%b", &a.first, &a.second);
}

static void dprint_qpoly(const qpoly &a)
{
    dprint("poly %d:", (int)a.size());
    for (size_t i = 0; i < a.size(); ++i) {
        dprint_qval(qval(a[i]));
    }
}
#endif

/*
 * Compute a - b*c for rationals held as (numerator, denominator)
 * pairs, dividing the result's numerator and denominator through by
 * their gcd before returning it.
 */
static qval a_minus_b_times_c(const qval &a, const qval &b, const qval &c)
{
    bigint den = a.second * b.second * c.second;
    bigint num = a.first * b.second * c.second
        - a.second * b.first * c.first;
    bigint common = gcd(num, den);
    return qval(num / common, den / common);
}

/*
 * Divide two polynomials over Q (coefficient vectors in increasing
 * order of degree). Returns the quotient; modifies n in place to be
 * the remainder, with all zero leading coefficients removed, so that
 * a zero remainder is an empty vector.
 *
 * d must be nonempty and have a nonzero leading coefficient; it is
 * only read.
 */
static qpoly divide_poly(qpoly &n, const qpoly &d)
{
    assert(d.size() > 0);
    assert(d[d.size()-1].first != 0);

    qpoly ret;

    for (;;) {
        // Normalize n first, so that its leading coefficient is
        // nonzero whenever we try to cancel it below.
        while (n.size() > 0 && n[n.size() - 1].first == 0)
            n.resize(n.size() - 1);
        if (n.size() < d.size())
            break;

//        dprint("n,d degrees %d,%d", (int)n.size(), (int)d.size());
        size_t shift = n.size() - d.size();
        const qval &nfirst = n[n.size() - 1];
        const qval &dfirst = d[d.size() - 1];

        // mult = nfirst / dfirst, so that subtracting mult * x^shift * d
        // cancels n's leading term exactly. It is computed before the
        // loop below overwrites n's leading entry.
        qval mult(nfirst.first * dfirst.second, nfirst.second * dfirst.first);
        for (size_t i = 0; i < d.size(); ++i) {
            n[i+shift] = a_minus_b_times_c(n[i+shift], d[i], mult);
        }

        // Quotient terms arrive in decreasing order of degree, so the
        // first one determines the quotient's length.
        if (ret.size() < shift+1)
            ret.resize(shift+1, qval(0, 1));
        ret[shift] = mult;
    }

    return ret;
}

/*
 * Scale a polynomial by a constant to make the coefficients all
 * integers and as small as possible: multiply through by the lcm of
 * the denominators, then divide by the gcd of the resulting
 * numerators. Every denominator is 1 afterwards.
 *
 * An empty polynomial is left untouched. If every numerator is zero,
 * the coefficients are left as zeros over 1 rather than divided by a
 * zero gcd.
 */
static void scale_poly(qpoly &P)
{
    // No coefficients: nothing to scale, and no P[0] to seed the gcd.
    if (P.size() == 0)
        return;

    bigint lcm_of_denoms = 1;
    for (size_t i = 0; i < P.size(); ++i) {
        // Dividing before multiplying gives the same exact result with
        // a smaller intermediate value.
        lcm_of_denoms = lcm_of_denoms / gcd(lcm_of_denoms, P[i].second) *
            P[i].second;
    }
    for (size_t i = 0; i < P.size(); ++i) {
        P[i].first = P[i].first * lcm_of_denoms / P[i].second;
        P[i].second = 1;
    }

    bigint gcd_of_nums = P[0].first;
    for (size_t i = 1; i < P.size(); ++i) {
        gcd_of_nums = gcd(gcd_of_nums, P[i].first);
    }
    if (gcd_of_nums == 0)
        return;
    for (size_t i = 0; i < P.size(); ++i) {
        P[i].first /= gcd_of_nums;
    }
}

/*
 * Reduce a polynomial to one with no repeated roots, by dividing off
 * the gcd of P and its derivative.
 */
static void reduce_poly(std::vector<bigint> &Porig)
{
    qpoly P, Pprime;

    // Copy P into a qpoly.
    for (size_t i = 0; i < Porig.size(); ++i) {
        P.push_back(qval(Porig[i], 1));
    }

    // And differentiate P.
    for (size_t i = 1; i < Porig.size(); ++i) {
        Pprime.push_back(qval((int)i * Porig[i], 1));
    }

//    dprint("P:"); dprint_qpoly(P);
//    dprint("Pprime:"); dprint_qpoly(Pprime);

    // Now do Euclid's algorithm between P and Pprime, i.e. repeatedly
    // reduce one of them mod the other.
    qpoly *a = &P, *b = &Pprime;
    while (b->size() > 0) {
//        dprint("a,b degrees %d,%d", (int)a->size(), (int)b->size());
//        dprint("a:"); dprint_qpoly(*a);
//        dprint("b:"); dprint_qpoly(*b);
        divide_poly(*a, *b);
        scale_poly(*a);
        qpoly *tmp = a; a = b; b = tmp;
    }

//    dprint("gcd:"); dprint_qpoly(*a);

    // Now *a is the gcd. Make another copy of the original P, and
    // divide that off.
    qpoly Pcopy;
    for (size_t i = 0; i < Porig.size(); ++i) {
        Pcopy.push_back(qval(Porig[i], 1));
    }
    qpoly ret = divide_poly(Pcopy, *a);
//    dprint("ret:"); dprint_qpoly(ret);
    scale_poly(ret);

    // Done. Copy into Porig for return.
    Porig.resize(0);
    for (size_t i = 0; i < ret.size(); ++i) {
        Porig.push_back(ret[i].first);
    }
}

/*
 * Interface convenient for expr.cpp, which will pass us a single
 * vector of spigots holding the endpoints _and_ the polynomial
 * coefficients. args[0] and args[1] are the two rational bounds of the
 * interval. args[2], args[3], ... are the integer coefficients,
 * constant term first. The polynomial is stripped of repeated roots
 * before being handed on to spigot_algebraic().
 */
Spigot *spigot_algebraic_wrapper(const std::vector<Spigot *> &args)
{
    if (args.size() < 2)
        throw spigot_error("expected at least two arguments to 'algebraic'");

    bigint nlo, dlo, nhi, dhi;
    bool bounds_are_rational =
        args[0]->is_rational(&nlo, &dlo) && args[1]->is_rational(&nhi, &dhi);
    if (!bounds_are_rational) {
        // We could, I suppose, attempt to tolerate any old real
        // number as an interval bound, by being prepared to fetch
        // convergents to it until we got one on the right side of the
        // input real at which P had the right sign.
        throw spigot_error("expected rational bounds for 'algebraic'");
    }

    // Express both bounds over one common denominator d, using Euclid's
    // algorithm on |dlo| and |dhi| to find their gcd g.
    bigint g = bigint_abs(dlo);
    bigint r = bigint_abs(dhi);
    while (r != 0) {
        bigint rem = g % r;
        g = r;
        r = rem;
    }
    bigint d = dlo * dhi / g;
    nlo *= dhi / g;
    nhi *= dlo / g;

    // Collect the coefficients, all of which must be integers.
    std::vector<bigint> P;
    for (std::vector<Spigot *>::const_iterator it = args.begin() + 2;
         it != args.end(); ++it) {
        // Again, we could be more tolerant here: we could cope with
        // rational polynomial coefficients as well as integers, by
        // recording all their denominators and scaling up afterwards.
        bigint ncoeff, dcoeff;
        bool is_integer =
            (*it)->is_rational(&ncoeff, &dcoeff) && dcoeff == 1;
        if (!is_integer)
            throw spigot_error("expected integer coefficients for 'algebraic'");
        P.push_back(ncoeff);
    }

    // Get rid of repeated roots before searching for one.
    reduce_poly(P);

    return spigot_algebraic(P, nlo, nhi, d);
}
