/*
  Faster bignum factoring code for coreutils factor.c.

  The improvement comes from the use of mpn low-level functions and that we use
  Montgomery form as much as possible, making mod operations cheap.

  Uses GMP's mpn_trialdiv for initial, fast factorisation.  Since these factors
  always appear in ascending order, print them directly.  Caveats: (1)
  mpn_trialdiv is not in GMP's public API. (2) mpn_trialdiv is not in
  mini-gmp.c.

  To create a new coreutils worthy factor.c, these issues need to be addressed:

  * Probably use the trial division code from the current coreutils factor.c,
    it is not as fast as mpn_trialdiv, but probably fast enough.  The dependency
    on the non-public mpn_trialdiv is uncool.

  * Get rid of uintmax; its size is undefined and not easily relatable to GMP's
    mp_limb_t.  Use mp_limb_t instead (or define something from mp_limb_t if
    that feels better).

  * Keep the single-word (single uintmax) code, but make it use a single GMP
    limb, i.e., mp_limb_t.  It should be made to use "struct factor" from this
    new code.

  * Probably keep the dual-word code too, but with the new faster bignum code
    below, this is not as clear cut.  Some benchmarking would help.

  Furthermore, this would make sense:

  * Remove squfof from coreutils factor.c, it never seems to be useful.

  * Consider making probabilistic prime tests the default.  The current prime
    proving code just makes things slow.  It is extremely unlikely that there
    are ANY pseudoprimes among the factors which Pollard rho is able to find.

  If this code is integrated and the above suggested changes are made, we will
  have a much faster factor.c which is also considerably less complex.
*/

#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#include "gmp.h"

/* One distinct factor together with the number of times it was recorded.
   pollard_mpn only inserts values that pass mpz_probab_prime_p.  */
struct factor {
  mpz_t f;			/* the factor value; cleared by factor() */
  size_t multiplicity;		/* how many times it divides the number */
};

/* Growable list of distinct factors.  factor_insert_z keeps factor_list
   sorted in ascending order and resizes it with realloc; factor() frees it.  */
struct factors {
  int n_distinct_factors;	/* number of valid entries in factor_list */
  struct factor *factor_list;	/* heap array of n_distinct_factors entries */
};

/* Record the factor {fp, fn} (fn limbs, least significant limb first) in
   FACTORS.  factor_list is kept in ascending order: if the value is already
   present, its multiplicity is incremented; otherwise a copy is inserted at
   its sorted position with multiplicity 1.  {fp, fn} is only read.  Aborts
   with a message if the list cannot be grown.  */
static void
factor_insert_z (struct factors *factors, mp_limb_t *fp, mp_size_t fn)
{
  long n_distinct_factors = factors->n_distinct_factors;
  struct factor *list = factors->factor_list;

  /* Read-only mpz view of {fp, fn}: it borrows fp, so there is nothing to
     allocate or clear just to compare.  */
  mpz_t view;
  mpz_srcptr fz = mpz_roinit_n (view, fp, fn);

  /* Walk down from the largest entry to find either a match or the
     insertion point; idx stays 0 if every entry is larger.  */
  long idx = 0;
  for (long i = n_distinct_factors - 1; i >= 0; i--) {
    int sgn = mpz_cmp (list[i].f, fz);
    if (sgn == 0) {
      list[i].multiplicity++;
      return;
    }
    if (sgn < 0) {
      idx = i + 1;
      break;
    }
  }

  /* Grow by one slot.  Keep the old pointer until realloc succeeds, so a
     failure is reported instead of losing (and then dereferencing) the list.  */
  struct factor *grown =
    realloc (list, (size_t) (n_distinct_factors + 1) * sizeof *grown);
  if (grown == NULL) {
    fputs ("factor: memory exhausted\n", stderr);
    abort ();
  }
  list = grown;

  /* Shift entries [idx, n) up one slot and store a copy of the new value.  */
  memmove (list + idx + 1, list + idx,
	   (size_t) (n_distinct_factors - idx) * sizeof *list);
  mpz_init_set (list[idx].f, fz);
  list[idx].multiplicity = 1;

  factors->n_distinct_factors = n_distinct_factors + 1;
  factors->factor_list = list;
}

/* Entry i contains (2i+1)^(-1) mod 2^8, i.e. the inverse of the odd byte
   2i+1 modulo 256.  binv_limb indexes it with (a / 2) & 0x7F to get an
   inverse of the odd limb a that is correct in its low 8 bits, then refines
   it with Newton steps.  */
static const unsigned char binvert_table[128] =
{
  0x01, 0xAB, 0xCD, 0xB7, 0x39, 0xA3, 0xC5, 0xEF,
  0xF1, 0x1B, 0x3D, 0xA7, 0x29, 0x13, 0x35, 0xDF,
  0xE1, 0x8B, 0xAD, 0x97, 0x19, 0x83, 0xA5, 0xCF,
  0xD1, 0xFB, 0x1D, 0x87, 0x09, 0xF3, 0x15, 0xBF,
  0xC1, 0x6B, 0x8D, 0x77, 0xF9, 0x63, 0x85, 0xAF,
  0xB1, 0xDB, 0xFD, 0x67, 0xE9, 0xD3, 0xF5, 0x9F,
  0xA1, 0x4B, 0x6D, 0x57, 0xD9, 0x43, 0x65, 0x8F,
  0x91, 0xBB, 0xDD, 0x47, 0xC9, 0xB3, 0xD5, 0x7F,
  0x81, 0x2B, 0x4D, 0x37, 0xB9, 0x23, 0x45, 0x6F,
  0x71, 0x9B, 0xBD, 0x27, 0xA9, 0x93, 0xB5, 0x5F,
  0x61, 0x0B, 0x2D, 0x17, 0x99, 0x03, 0x25, 0x4F,
  0x51, 0x7B, 0x9D, 0x07, 0x89, 0x73, 0x95, 0x3F,
  0x41, 0xEB, 0x0D, 0xF7, 0x79, 0xE3, 0x05, 0x2F,
  0x31, 0x5B, 0x7D, 0xE7, 0x69, 0x53, 0x75, 0x1F,
  0x21, 0xCB, 0xED, 0xD7, 0x59, 0xC3, 0xE5, 0x0F,
  0x11, 0x3B, 0x5D, 0xC7, 0x49, 0x33, 0x55, 0xFF
};

/* Return the inverse of the odd limb A modulo B = 2^GMP_LIMB_BITS, i.e. the
   value x with a * x == 1 (mod B).  The result is meaningless for even A.

   binvert_table provides an inverse that is correct in the low 8 bits.  Each
   Newton step x <- x * (2 - a * x) doubles the number of correct low bits
   (if a*x = 1 - e then a*x*(2 - a*x) = 1 - e^2).  The loop therefore runs
   until all GMP_LIMB_BITS bits are covered, whatever the limb size.  For 32-
   and 64-bit limbs this performs the same 2 or 3 steps as an unrolled
   version.  All arithmetic is unsigned and wraps modulo B by design.  */
mp_limb_t
binv_limb (mp_limb_t a)
{
  mp_limb_t x = binvert_table[(a / 2) & 0x7F];

  for (int bits = 8; bits < GMP_LIMB_BITS; bits *= 2)
    x *= 2 - a * x;

  return x;
}

/* Montgomery multiplication (REDC): rp = ap * bp * B^(-n) mod mp, where
   B = 2^GMP_LIMB_BITS and every operand is n limbs long.

   m0inv must satisfy mp[0] * m0inv == -1 (mod B); pollard_mpn passes
   binv_limb (-mp[0]).  tp is scratch space of at least 2n limbs (tp[n + i]
   is written for i up to n - 1) and must not overlap the other operands.
   rp may equal ap or bp, because rp is written only after ap and bp have
   been fully consumed.

   Only one final subtraction of mp is performed, so the result is fully
   reduced only if the pre-subtraction value is below 2 * mp.  That holds
   when ap, bp < mp, which is an assumed precondition that is not checked.  */
static void
mulredc (mp_ptr rp, mp_srcptr ap, mp_srcptr bp, mp_srcptr mp, mp_size_t n, mp_limb_t m0inv, mp_ptr tp)
{
  mp_size_t i;
  mp_limb_t cy;

  /* Limb 0: tp[0..n] = ap * bp[0].  Then add q * mp with q = tp[0] * m0inv,
     which makes limb 0 zero.  The carry out of that addmul belongs at
     limb n; rather than propagating it, park it in the now-zero tp[0].  */
  tp[n] = mpn_mul_1 (tp, ap, n, bp[0]);
  tp[0] = mpn_addmul_1 (tp, mp, n, tp[0] * m0inv);

  /* Limbs 1..n-1: the same multiply-then-reduce pattern, shifted by i.
     Afterwards tp[n..2n-1] is the high half of the running sum, and
     tp[0..n-1] holds the parked carries, with tp[i] belonging at limb n + i.  */
  for (i = 1; i < n; i++) {
    tp[n + i] = mpn_addmul_1 (tp + i, ap, n, bp[i]);
    tp[i] = mpn_addmul_1 (tp + i, mp, n, tp[i] * m0inv);
  }
  /* Fold the parked carries into the high half.  A carry out of the top
     limb, or a value >= mp, calls for one subtraction of mp.  */
  cy = mpn_add_n (rp, tp, tp + n, n);
  if (cy || mpn_cmp (rp, mp, n) >= 0) {
    mpn_sub_n (rp, rp, mp, n);
  }
}

/* rp = (ap + bp) mod mp, all operands n limbs.
   At most one subtraction of mp is made, so the operands must already be
   reduced (< mp), which keeps the raw sum below 2 * mp.  A carry out of the
   top limb means the sum certainly exceeds mp.  */
static void
modadd (mp_ptr rp, mp_srcptr ap, mp_srcptr bp, mp_srcptr mp, mp_size_t n)
{
  const mp_limb_t carry = mpn_add_n (rp, ap, bp, n);
  const int needs_reduce = carry != 0 || mpn_cmp (rp, mp, n) >= 0;

  if (needs_reduce)
    mpn_sub_n (rp, rp, mp, n);
}

/* rp = (ap - bp) mod mp, all operands n limbs and assumed < mp.
   A borrow out of the top limb means the difference went negative, and
   adding mp back once brings it into [0, mp).  */
static void
modsub (mp_ptr rp, mp_srcptr ap, mp_srcptr bp, mp_srcptr mp, mp_size_t n)
{
  if (mpn_sub_n (rp, ap, bp, n) != 0)
    mpn_add_n (rp, rp, mp, n);
}

/* rp = (ap + b0) mod mp for an n-limb ap and a single-limb addend b0.
   Only one subtraction of mp is made, so ap + b0 must stay below 2 * mp.  */
static void
modadd_ui (mp_ptr rp, mp_srcptr ap, unsigned long b0, mp_srcptr mp, mp_size_t n)
{
  int reduce = mpn_add_1 (rp, ap, n, b0) != 0;

  if (!reduce)
    reduce = mpn_cmp (rp, mp, n) >= 0;

  if (reduce)
    mpn_sub_n (rp, rp, mp, n);
}

/* Split the odd composite {mp, n} into (probable) primes with Pollard's rho
   method and record them in FACTORS via factor_insert_z.

   Residues are kept in Montgomery form (R = B^n, B = 2^GMP_LIMB_BITS), so
   reductions are done by mulredc.  The walk is x <- mulredc (x, x) + a.
   Following Brent, x is compared against a reference point z that is fixed
   for a whole round, and rounds double in length.  The differences z - x
   are multiplied into P, and gcd (P, m) is taken only every 128 steps and at
   the end of each round.  On a nontrivial gcd, the walk is replayed one step
   at a time from Y, the x saved at the previous check, to isolate the
   factor.  If P becomes zero mod m, the walk restarts with a + 1.

   Requirements (met by the callers in this file): m is odd and composite,
   and mp[n - 1] != 0.  {mp, n} is only read.  Scratch arrays are VLAs on
   the stack (O(n) limbs per call), and composite pieces are split by
   recursion.  */
static void
pollard_mpn (struct factors *factors, mp_ptr mp, mp_size_t n, unsigned long a)
{
  mp_limb_t qp[n + 2];
  mp_limb_t pp[n], xp[n], yp[n], zp[n], tp[n], sp[n], gp[n], scratch[2 * n + 1];
  mp_size_t gn;
  mpz_t t;			/* read-only view for primality tests */

  /* pp = B^n mod m, the Montgomery form of 1; it seeds the product P.  */
  mpn_zero (scratch, n);
  scratch[n] = 1;
  mpn_tdiv_qr (qp, pp, 0, scratch, n + 1, mp, n);

  /* x = y = z = 2 in Montgomery form.  */
  modadd (xp, pp, pp, mp, n);
  mpn_copyi (yp, xp, n);
  mpn_copyi (zp, xp, n);

  unsigned long int k = 1;
  /* mp[0] * m0inv == -1 (mod B), as mulredc requires.  */
  mp_limb_t m0inv = binv_limb (-mp[0]);

  for (;;) {
    /* Accumulate k differences z - x into P.  */
    for (unsigned long int i = k; i != 0; i--) {

      mulredc (tp, xp, xp, mp, n, m0inv, scratch);
      modadd_ui (xp, tp, a, mp, n);

      modsub (tp, zp, xp, mp, n);
      mulredc (pp, pp, tp, mp, n, m0inv, scratch);

      if (i % 128 == 1) {
	if (mpn_zero_p (pp, n)) {
	  /* m divides P, so no factor can be separated; try another walk.  */
	  pollard_mpn (factors, mp, n, a + 1);
	  return;
	}
	/* mpn_gcd destroys both inputs, so work on copies.  */
	mpn_copyi (tp, pp, n);
	mpn_copyi (sp, mp, n);
	gn = mpn_gcd (gp, tp, n, sp, n);
	if (gn != 1 || gp[0] != 1)
	  goto factor_found;
	/* Checkpoint for a later replay.  */
	mpn_copyi (yp, xp, n);
      }
    }

    /* Move the reference point, then advance x k steps (after doubling k)
       without comparing before the next round.  */
    mpn_copyi (zp, xp, n);
    k = 2 * k;
    for (unsigned long int i = k; i != 0; i--) {
      mulredc (tp, xp, xp, mp, n, m0inv, scratch);
      modadd_ui (xp, tp, a, mp, n);
    }
    mpn_copyi (yp, xp, n);
  }

 factor_found:
  /* Replay the steps since the last checkpoint until a single difference
     z - y shares a factor with m.  */
  do {
    mulredc (tp, yp, yp, mp, n, m0inv, scratch);
    modadd_ui (yp, tp, a, mp, n);
    modsub (tp, zp, yp, mp, n);
    mpn_copyi (sp, mp, n);
    gn = mpn_gcd (gp, tp, n, sp, n);
  } while (gn == 1 && gp[0] == 1);

  /* g = {gp, gn} satisfies 1 < g < m.  A difference divisible by m would
     have made P zero, which is caught above.  */
  if (mpz_probab_prime_p (mpz_roinit_n (t, gp, gn), 10)) {
    factor_insert_z (factors, gp, gn);
  } else {
    pollard_mpn (factors, gp, gn, a + 1);
  }

  /* q = m / g (exact).  Since mp[n - 1] != 0 and gp[gn - 1] != 0, q has
     either n - gn or n - gn + 1 limbs, so only its top limb qp[n - gn] can
     be zero.  qp[n - 1] lies outside the quotient whenever gn > 1, so it
     must not be tested here.  */
  mpn_tdiv_qr (qp, tp, 0, mp, n, gp, gn);	/* could use divexact */
  mp_size_t qn = n - gn + (qp[n - gn] != 0);

  if (mpz_probab_prime_p (mpz_roinit_n (t, qp, qn), 10)) {
    factor_insert_z (factors, qp, qn);
  } else {
    pollard_mpn (factors, qp, qn, a + 1);
  }
}

/* GMP-internal trial division (not in GMP's public API; see the file
   header).  As used below, a nonzero return is the limb inverse of the first
   small prime that divides the operand (binv_limb recovers the prime), and
   0 means none does.  *WHERE carries the position to resume from.
   NOTE(review): contract inferred from usage here; confirm against GMP's
   mpn/generic/trialdiv.c.  */
mp_limb_t __gmpn_trialdiv (mp_srcptr, mp_size_t, mp_size_t, int *);

/* Print the prime factorization of N to stdout as " p1 p2 ..." in ascending
   order, followed by a newline.  N is modified (divided down) in the
   process.  N <= 0 has no factorization here, so only the newline is
   printed.  */
void
factor (mpz_t N)
{
  /* Guard: mpz_scan1 (0, 0) returns ~0, which would make the " 2" loop run
     away, and a negative N would hand a negative limb count to the mpn
     routines.  */
  if (mpz_sgn (N) <= 0) {
    puts ("");
    return;
  }

  /* Powers of two.  */
  mp_bitcnt_t twos = mpz_scan1 (N, 0);
  mpz_tdiv_q_2exp (N, N, twos);
  for (mp_bitcnt_t i = 0; i < twos; i++)
    printf (" 2");

  /* Small primes are found in ascending order, so print them right away.
     N changes on every iteration, so re-read its limbs each time.  */
  int where = 0;
  for (;;) {
    mp_limb_t finv = __gmpn_trialdiv (N->_mp_d, N->_mp_size, 40000, &where);
    if (finv == 0)
      break;

    mp_limb_t f = binv_limb (finv);
    mpz_tdiv_q_ui (N, N, f);
    gmp_printf (" %Mu", f);
    fflush (stdout);
  }

  if (mpz_cmp_ui (N, 1) == 0) {
    puts ("");
    return;
  }

  if (mpz_probab_prime_p (N, 10)) {
    gmp_printf (" %Zd", N);
    puts ("");
    return;
  }

  /* The remainder is odd, > 1 and not (probably) prime, so split it with
     Pollard rho.  The list starts empty; realloc (NULL, ...) in
     factor_insert_z allocates it on first insert.  */
  struct factors factors = { .n_distinct_factors = 0, .factor_list = NULL };
  pollard_mpn (&factors, N->_mp_d, N->_mp_size, 1);

  for (int i = 0; i < factors.n_distinct_factors; i++) {
    for (size_t m = 0; m < factors.factor_list[i].multiplicity; m++)
      gmp_printf (" %Zd", factors.factor_list[i].f);
    mpz_clear (factors.factor_list[i].f);
  }
  free (factors.factor_list);
  puts ("");
}

/* Factor each command-line argument, parsed with mpz_set_str base 0 (so
   "0x" and "0" prefixes select hex and octal).  Output is "N: p1 p2 ...".
   Arguments that do not parse are reported on stderr and make the exit
   status 1.  With no arguments, run a self-test on 1000 random odd numbers
   of up to 160 bits.  */
int
main (int argc, const char* argv[])
{
  mpz_t N;
  int status = 0;

  mpz_init (N);

  if (argc > 1) {
    for (int i = 1; i < argc; i++) {
      /* mpz_set_str returns -1 for an invalid string; N's value is then not
	 meaningful, so do not try to factor it.  */
      if (mpz_set_str (N, argv[i], 0) != 0) {
	fprintf (stderr, "%s: not a valid number\n", argv[i]);
	status = 1;
	continue;
      }
      gmp_printf ("%Zd:", N);
      if (mpz_probab_prime_p (N, 100)) {
	gmp_printf (" %Zd", N);
	puts ("");
      } else {
	factor (N);
      }
    }
  } else {
    gmp_randstate_t rs;
    mpz_t x;
    mpz_init (x);
    gmp_randinit_default (rs);
    for (long i = 0; i < 1000; i++) {
      /* Random size, then a random odd number >= 3 of roughly that size.  */
      mpz_urandomb (x, rs, 32);
      int bits = mpz_get_ui (x) % 160;
      mpz_rrandomb (N, rs, bits + 1);
      mpz_add_ui (N, N, 2);

      mpz_setbit (N, 0);

      gmp_printf ("TEST %ld: %Zd:", i, N);
      fflush (stdout);

      factor (N);
    }
    gmp_randclear (rs);
    mpz_clear (x);
  }

  mpz_clear (N);
  return status;
}
