#include <gmpz.h>
#include <gmpq.h>
#include <gmpf.h>

/*
 * Document-class: GMP::Z
 *
 * GMP Multiple Precision Integer.
 *
 * Instances of this class can store variables of the type mpz_t. This class
 * also contains many methods that act as the functions for mpz_t variables,
 * as well as a few methods that attempt to make this library more Ruby-ish.
 */

/**********************************************************************
 *    Macros                                                          *
 **********************************************************************/

/*
 * DEFUN_INT2INT(fname, mpz_fname) expands to two static functions wrapping
 * mpz_fname, a GMP function that is called with exactly two arguments:
 * (0) the destination and (1) the operand.
 *
 * * r_gmpz_<fname> is the non-destructive form: it applies mpz_fname to
 *   self, stores the outcome in a newly made result and returns that
 *   result.
 * * r_gmpz_<fname>_self is the destructive form: it applies mpz_fname to
 *   self in place (self is both destination and operand) and returns self.
 */
#define DEFUN_INT2INT(fname, mpz_fname)          \
static VALUE r_gmpz_##fname(VALUE self)          \
{                                                \
  VALUE ret;                                     \
  MP_INT *operand, *target;                      \
                                                 \
  mpz_get_struct(self, operand);                 \
  mpz_make_struct_init(ret, target);             \
  mpz_fname(target, operand);                    \
  return ret;                                    \
}                                                \
                                                 \
static VALUE r_gmpz_##fname##_self(VALUE self)   \
{                                                \
  MP_INT *operand;                               \
                                                 \
  mpz_get_struct(self, operand);                 \
  mpz_fname(operand, operand);                   \
  return self;                                   \
}


/**********************************************************************
 *    Integer Arithmetic                                              *
 **********************************************************************/

/*
 * call-seq:
 *   a + b
 *
 * Returns the sum of _a_ and _b_. _b_ must be an instance of one of:
 * * GMP::Z
 * * Fixnum
 * * GMP::Q
 * * GMP::F
 * * Bignum
 *
 * For a GMP::Z, Fixnum or Bignum the sum is computed into a newly made
 * integer result. For a GMP::Q or GMP::F the addition is handed to _b_ with
 * the operands swapped, and whatever that addition returns is returned here.
 * Any other argument is rejected through typeerror.
 */
VALUE r_gmpz_add(VALUE self, VALUE arg)
{
  MP_INT *self_val, *arg_val, *res_val;
  VALUE res;

  mpz_get_struct(self,self_val);

  if (GMPZ_P(arg)) {
    mpz_get_struct(arg,arg_val);
    mpz_make_struct_init(res, res_val);
    mpz_add(res_val, self_val, arg_val);
  } else if (FIXNUM_P(arg)) {
    mpz_make_struct_init(res, res_val);
    /* mpz_add_ui/mpz_sub_ui take an unsigned operand, so a non-positive
     * Fixnum is added by subtracting its magnitude. */
    if (FIX2NUM(arg) > 0) {
      mpz_add_ui(res_val, self_val, FIX2NUM(arg));
    } else {
      mpz_sub_ui(res_val, self_val, -FIX2NUM(arg));
    }
  } else if (GMPQ_P(arg)) {
    /* Addition is commutative, so reuse the GMP::Q implementation. */
    return r_gmpq_add(arg, self);
  } else if (GMPF_P(arg)) {
#ifndef MPFR
    return r_gmpf_add(arg, self);
#else
    /* With MPFR, dispatch through Ruby's "+" on the float operand. */
    return rb_funcall(arg, rb_intern("+"), 1, self);
#endif
  } else if (BIGNUM_P(arg)) {
    /* res_val is used straight after mpz_make_struct_init in every other
     * branch, so it must not be passed to mpz_init a second time here:
     * re-initializing would discard the storage set up by the first
     * initialization. Load the Bignum into the result, then add self. */
    mpz_make_struct_init(res, res_val);
    mpz_set_bignum(res_val, arg);
    mpz_add(res_val, res_val, self_val);
  } else {
    typeerror(ZQFXB);
  }
  return res;
}

/*
 * call-seq:
 *   a.add!(_b_)
 *
 * Adds _b_ to _a_ in-place, setting _a_ to the sum. Returns +nil+.
 *
 * _b_ must be an instance of one of:
 * * GMP::Z
 * * Fixnum
 * * Bignum
 *
 * GMP::Q and GMP::F arguments are not accepted: like any other unsupported
 * type they are rejected through typeerror.
 */
VALUE r_gmpz_add_self(VALUE self, VALUE arg)
{
  MP_INT *self_val, *arg_val;

  mpz_get_struct(self,self_val);

  if (GMPZ_P(arg)) {
    mpz_get_struct(arg,arg_val);
    mpz_add(self_val, self_val, arg_val);
  } else if (FIXNUM_P(arg)) {
    /* mpz_add_ui/mpz_sub_ui take an unsigned operand, so a non-positive
     * Fixnum is added by subtracting its magnitude. */
    if (FIX2NUM(arg) > 0) {
      mpz_add_ui(self_val, self_val, FIX2NUM(arg));
    } else {
      mpz_sub_ui(self_val, self_val, -FIX2NUM(arg));
    }
  } else if (BIGNUM_P(arg)) {
    /* Convert the Bignum into a temporary that is released right after
     * the addition. */
    mpz_temp_from_bignum(arg_val, arg);
    mpz_add(self_val, self_val, arg_val);
    mpz_temp_free(arg_val);
  } else {
    typeerror(ZXB);
  }
  return Qnil;
}

/*
 * call-seq:
 *   a - b
 *
 * Returns the difference _a_ - _b_. _b_ must be an instance of one of:
 * * GMP::Z
 * * Fixnum
 * * GMP::Q
 * * GMP::F
 * * Bignum
 *
 * Every supported case, including GMP::Q and GMP::F, is computed directly
 * in this function. Any other argument is rejected through typeerror.
 */
VALUE r_gmpz_sub(VALUE self, VALUE arg)
{
  VALUE diff;
  MP_INT *minuend, *rhs_z, *diff_z;
  MP_RAT *rhs_q, *diff_q;
  MP_FLOAT *rhs_f, *diff_f;
  unsigned long precision;

  mpz_get_struct(self, minuend);

  if (GMPZ_P(arg)) {
    mpz_make_struct_init(diff, diff_z);
    mpz_get_struct(arg, rhs_z);
    mpz_sub(diff_z, minuend, rhs_z);
  } else if (FIXNUM_P(arg)) {
    mpz_make_struct_init(diff, diff_z);
    /* The *_ui functions take an unsigned operand: subtracting a
     * non-positive Fixnum is done by adding its magnitude. */
    if (FIX2NUM(arg) <= 0) {
      mpz_add_ui(diff_z, minuend, -FIX2NUM(arg));
    } else {
      mpz_sub_ui(diff_z, minuend, FIX2NUM(arg));
    }
  } else if (GMPQ_P(arg)) {
    /* a - n/d == (a*d - n) / d: build the numerator first, then copy the
     * denominator of the argument. */
    mpq_make_struct_init(diff, diff_q);
    mpq_get_struct(arg, rhs_q);
    mpz_mul(mpq_numref(diff_q), mpq_denref(rhs_q), minuend);
    mpz_sub(mpq_numref(diff_q), mpq_numref(diff_q), mpq_numref(rhs_q));
    mpz_set(mpq_denref(diff_q), mpq_denref(rhs_q));
  } else if (GMPF_P(arg)) {
    /* The result is made with the precision read from the argument. */
    mpf_get_struct_prec(arg, rhs_f, precision);
    mpf_make_struct_init(diff, diff_f, precision);
    mpf_set_z(diff_f, minuend);
    mpf_sub(diff_f, diff_f, rhs_f);
  } else if (BIGNUM_P(arg)) {
    /* Load the Bignum into the result, then overwrite it with
     * self - result. */
    mpz_make_struct_init(diff, diff_z);
    mpz_set_bignum(diff_z, arg);
    mpz_sub(diff_z, minuend, diff_z);
  } else {
    typeerror(ZQFXB);
  }
  return diff;
}

/*
 * call-seq:
 *   a.sub!(b)
 *
 * Subtracts _b_ from _a_ in-place, setting _a_ to the difference. Returns
 * +nil+.
 *
 * _b_ must be an instance of one of:
 * * GMP::Z
 * * Fixnum
 * * Bignum
 *
 * GMP::Q and GMP::F arguments are not accepted: like any other unsupported
 * type they are rejected through typeerror.
 */
VALUE r_gmpz_sub_self(VALUE self, VALUE arg)
{
  MP_INT *self_val, *arg_val;

  mpz_get_struct(self,self_val);

  if (GMPZ_P(arg)) {
    mpz_get_struct(arg, arg_val);
    mpz_sub (self_val, self_val, arg_val);
  } else if (FIXNUM_P(arg)) {
    /* mpz_sub_ui/mpz_add_ui take an unsigned operand, so a non-positive
     * Fixnum is subtracted by adding its magnitude. */
    if (FIX2NUM(arg) > 0) {
      mpz_sub_ui (self_val, self_val, FIX2NUM(arg));
    } else {
      mpz_add_ui (self_val, self_val, -FIX2NUM(arg));
    }
  } else if (BIGNUM_P(arg)) {
    /* Convert the Bignum into a temporary that is released right after
     * the subtraction. */
    mpz_temp_from_bignum(arg_val, arg);
    mpz_sub (self_val, self_val, arg_val);
    mpz_temp_free (arg_val);
  } else {
    typeerror (ZXB);
  }
  return Qnil;
}

/*
 * call-seq:
 *   a * b
 *
 * Returns the product of _a_ and _b_. _b_ must be an instance of one of:
 * * GMP::Z
 * * Fixnum
 * * GMP::Q
 * * GMP::F
 * * Bignum
 *
 * For a GMP::Q or GMP::F the multiplication is handed to _b_ with the
 * operands swapped. Any other unsupported argument is rejected through
 * typeerror.
 */
VALUE r_gmpz_mul(VALUE self, VALUE arg)
{
  VALUE product;
  MP_INT *lhs, *rhs, *product_z;

  mpz_get_struct(self, lhs);

  if (GMPZ_P(arg)) {
    mpz_make_struct_init(product, product_z);
    mpz_get_struct(arg, rhs);
    mpz_mul(product_z, lhs, rhs);
  } else if (FIXNUM_P(arg)) {
    /* mpz_mul_si takes a signed operand, so no sign split is needed. */
    mpz_make_struct_init(product, product_z);
    mpz_mul_si(product_z, lhs, FIX2NUM(arg));
  } else if (GMPQ_P(arg)) {
    /* Multiplication is commutative, so reuse the GMP::Q implementation. */
    return r_gmpq_mul(arg, self);
  } else if (GMPF_P(arg)) {
#ifdef MPFR
    /* With MPFR, dispatch through Ruby's "*" on the float operand. */
    return rb_funcall(arg, rb_intern("*"), 1, self);
#else
    return r_gmpf_mul(arg, self);
#endif
  } else if (BIGNUM_P(arg)) {
    /* Load the Bignum into the result, then multiply it by self. */
    mpz_make_struct_init(product, product_z);
    mpz_set_bignum(product_z, arg);
    mpz_mul(product_z, product_z, lhs);
  } else {
    typeerror(ZQFXB);
  }
  return product;
}

/*
 * call-seq:
 *   a.addmul!(b, c)
 *
 * @since 0.4.17
 *
 * Sets _a_ to _a_ plus _b_ times _c_ and returns _a_. _b_ and _c_ must each
 * be an instance of one of
 * * GMP::Z
 * * Fixnum
 * * Bignum
 *
 * When _c_ is a Fixnum it must be nonnegative; a negative Fixnum raises
 * RangeError. An unsupported type for _b_ or _c_ is rejected through
 * typeerror_as.
 */
VALUE r_gmpz_addmul_self(VALUE self, VALUE b, VALUE c)
{
  MP_INT *accum, *factor, *multiplier;
  /* Nonzero when factor is a temporary made here that must be released. */
  int owns_factor = 0;

  /* Step 1: obtain b as an mpz, converting Fixnum/Bignum to a temporary. */
  if (GMPZ_P(b)) {
    mpz_get_struct(b, factor);
  } else if (FIXNUM_P(b)) {
    mpz_temp_alloc(factor);
    mpz_init_set_si(factor, FIX2NUM(b));
    owns_factor = 1;
  } else if (BIGNUM_P(b)) {
    mpz_temp_from_bignum(factor, b);
    owns_factor = 1;
  } else {
    typeerror_as(ZXB, "addend");
  }

  mpz_get_struct(self, accum);

  /* Step 2: accumulate factor * c into self. Every path that raises must
   * release the temporary first, because control does not come back. */
  if (GMPZ_P(c)) {
    mpz_get_struct(c, multiplier);
    mpz_addmul(accum, factor, multiplier);
  } else if (FIXNUM_P(c)) {
    /* mpz_addmul_ui takes an unsigned operand. */
    if (FIX2NUM(c) < 0) {
      if (owns_factor) {
        mpz_temp_free(factor);
      }
      rb_raise(rb_eRangeError, "multiplicand (Fixnum) must be nonnegative");
    }
    mpz_addmul_ui(accum, factor, FIX2NUM(c));
  } else if (BIGNUM_P(c)) {
    mpz_temp_from_bignum(multiplier, c);
    mpz_addmul(accum, factor, multiplier);
    mpz_temp_free(multiplier);
  } else {
    if (owns_factor) {
      mpz_temp_free(factor);
    }
    typeerror_as(ZXB, "multiplicand");
  }

  if (owns_factor) {
    mpz_temp_free(factor);
  }
  return self;
}

/*
 * Document-method: neg
 *
 * call-seq:
 *   a.neg
 *   -a
 *
 * Returns -_a_. _a_ itself is not modified.
 */
/*
 * Document-method: neg!
 *
 * call-seq:
 *   a.neg!
 *
 * Sets _a_ to -_a_ in place and returns _a_.
 */
/* Expands to r_gmpz_neg and r_gmpz_neg_self, both wrapping mpz_neg. */
DEFUN_INT2INT(neg, mpz_neg)
/*
 * Document-method: abs
 *
 * call-seq:
 *   a.abs
 *
 * Returns the absolute value of _a_. _a_ itself is not modified.
 */
/*
 * Document-method: abs!
 *
 * call-seq:
 *   a.abs!
 *
 * Sets _a_ to its absolute value in place and returns _a_.
 */
/* Expands to r_gmpz_abs and r_gmpz_abs_self, both wrapping mpz_abs. */
DEFUN_INT2INT(abs, mpz_abs)


/**********************************************************************
 *    Integer Roots                                                   *
 **********************************************************************/

/*
 * Document-method: sqrt
 *
 * call-seq:
 *   a.sqrt
 *
 * Returns the truncated integer part of the square root of _a_. _a_ itself
 * is not modified.
 */
/*
 * Document-method: sqrt!
 *
 * call-seq:
 *   a.sqrt!
 *
 * Sets _a_ to the truncated integer part of its square root and returns _a_.
 */
/*
 * Expands to r_gmpz_sqrt and r_gmpz_sqrt_self, both wrapping mpz_sqrt.
 * NOTE(review): the generated functions pass self to mpz_sqrt without any
 * sign check -- confirm how a negative _a_ is meant to be handled.
 */
DEFUN_INT2INT(sqrt, mpz_sqrt)

/*
 * call-seq:
 *   a.sqrtrem #=> s, r
 *
 * Returns the truncated integer part of the square root of _a_ as _s_ and the remainder
 * <i>a - s * s</i> as _r_, which will be zero if _a_ is a perfect square.
 *
 * _a_ must be nonnegative; a negative _a_ raises RangeError.
 */
static VALUE r_gmpz_sqrtrem(VALUE self)
{
  MP_INT *self_val, *sqrt_val, *rem_val;
  VALUE sqrt, rem;

  mpz_get_struct(self, self_val);

  /* GMP has no integer square root for a negative operand and treats the
   * attempt as a fatal arithmetic error, which would take the whole Ruby
   * process down. Reject it up front with an ordinary, rescuable exception
   * before any result object is made. */
  if (mpz_sgn(self_val) < 0) {
    rb_raise(rb_eRangeError, "cannot take the square root of a negative number");
  }

  mpz_make_struct_init(sqrt, sqrt_val);
  mpz_make_struct_init(rem, rem_val);
  mpz_sqrtrem(sqrt_val, rem_val, self_val);

  /* Hand both results back as a two-element array [s, r]. */
  return rb_assoc_new(sqrt, rem);
}


/**********************************************************************
 *    Init function                                                   *
 **********************************************************************/

/*
 * Registers the GMP module, the GMP::Z class and the GMP::Z methods
 * implemented in this file with the Ruby interpreter.
 *
 * The last argument of each rb_define_* call is the arity the method is
 * registered with: 0, 1 or 2 fixed arguments, or -1 for a variable number
 * of arguments.
 */
void init_gmpz()
{
  /* Define the GMP module and the module-level function GMP.Z, which is
   * registered with a variable number of arguments. */
  mGMP = rb_define_module("GMP");
  rb_define_module_function(mGMP, "Z", r_gmpmod_z, -1);

  /* GMP::Z is defined under GMP with rb_cInteger as its superclass. */
  cGMP_Z = rb_define_class_under(mGMP, "Z", rb_cInteger);

  // Integer Arithmetic
  // ("neg" and the unary minus operator "-@" share r_gmpz_neg.)
  rb_define_method(cGMP_Z, "+",       r_gmpz_add, 1);
  rb_define_method(cGMP_Z, "add!",    r_gmpz_add_self, 1);
  rb_define_method(cGMP_Z, "-",       r_gmpz_sub, 1);
  rb_define_method(cGMP_Z, "sub!",    r_gmpz_sub_self, 1);
  rb_define_method(cGMP_Z, "*",       r_gmpz_mul, 1);
  rb_define_method(cGMP_Z, "addmul!", r_gmpz_addmul_self, 2);
  rb_define_method(cGMP_Z, "neg",     r_gmpz_neg, 0);
  rb_define_method(cGMP_Z, "neg!",    r_gmpz_neg_self, 0);
  rb_define_method(cGMP_Z, "-@",      r_gmpz_neg, 0);
  rb_define_method(cGMP_Z, "abs",     r_gmpz_abs, 0);
  rb_define_method(cGMP_Z, "abs!",    r_gmpz_abs_self, 0);

  // Integer Roots
  rb_define_method(cGMP_Z, "root",    r_gmpz_root, 1);
  rb_define_method(cGMP_Z, "sqrt",    r_gmpz_sqrt, 0);
  rb_define_method(cGMP_Z, "sqrt!",   r_gmpz_sqrt_self, 0);
  rb_define_method(cGMP_Z, "sqrtrem", r_gmpz_sqrtrem, 0);
}
